import gradio as gr
import torch
import json
import time

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from fastapi import FastAPI
from pydantic import BaseModel

# Label names in model output order: index i of the model logits corresponds
# to LABEL_NAMES[i].
LABEL_NAMES = [
    "L01_DISEASE_CLAIM",
    "L02_STRUCTURE_FUNCTION",
    "L03_UNSUBSTANTIATED_EFFICACY",
    "L04_DRUG_COMPARISON",
    "L05_DOSING_INSTRUCTIONS",
    "L06_IMPLIED_HUMAN_USE",
    "L07_PHARMA_GRADE_CLAIM",
    "L08_TESTIMONIAL_VIOLATION",
    "L09_MISSING_DISCLAIMER",
    "L10_MISLEADING_DISCLAIMER",
    "L11_MISSING_COA",
    "L12_FDA_FALSE_CLAIM",
    "L13_MISSING_DISCLOSURES",
    "L14_FAIR_BALANCE_VIOLATION",
    "L19_DECEPTIVE_PRICING",
    "L21_CONSUMER_LANGUAGE",
    "L22_BODYBUILDING_CONTENT",
    "L23_META_HEALTH_KEYWORDS",
    "L24_COVER_CONTRADICTION",
    "L25_MISSING_RESEARCH_ID",
    "L27_THCA_LEGALITY_CLAIM",
    "L28_CBD_DISEASE_CLAIM",
    "L31_COMPLIANT_RESEARCH",
    "L32_COMPLIANT_DISCLAIMER",
    "L33_COMPLIANT_MARKETING",
]

# Optimized per-label decision thresholds from training.
THRESHOLDS = {
    "L01_DISEASE_CLAIM": 0.35,
    "L02_STRUCTURE_FUNCTION": 0.55,
    "L03_UNSUBSTANTIATED_EFFICACY": 0.15,
    "L04_DRUG_COMPARISON": 0.125,
    "L05_DOSING_INSTRUCTIONS": 0.05,
    "L06_IMPLIED_HUMAN_USE": 0.275,
    "L07_PHARMA_GRADE_CLAIM": 0.375,
    "L08_TESTIMONIAL_VIOLATION": 0.25,
    "L09_MISSING_DISCLAIMER": 0.15,
    "L10_MISLEADING_DISCLAIMER": 0.5,
    "L11_MISSING_COA": 0.65,
    "L12_FDA_FALSE_CLAIM": 0.05,
    "L13_MISSING_DISCLOSURES": 0.075,
    "L14_FAIR_BALANCE_VIOLATION": 0.325,
    "L19_DECEPTIVE_PRICING": 0.05,
    "L21_CONSUMER_LANGUAGE": 0.075,
    "L22_BODYBUILDING_CONTENT": 0.075,
    "L23_META_HEALTH_KEYWORDS": 0.05,
    "L24_COVER_CONTRADICTION": 0.15,
    "L25_MISSING_RESEARCH_ID": 0.425,
    "L27_THCA_LEGALITY_CLAIM": 0.15,
    "L28_CBD_DISEASE_CLAIM": 0.175,
    "L31_COMPLIANT_RESEARCH": 0.4,
    "L32_COMPLIANT_DISCLAIMER": 0.3,
    "L33_COMPLIANT_MARKETING": 0.05,
}

MODEL_ID = "Austin-Groundsetter/deberta-prism-v2"

# Load the model once at import time so every request reuses it.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()
print("Model loaded!")


def predict(text: str) -> dict:
    """Classify *text* and return per-label predictions with confidence scores.

    Args:
        text: Raw text to classify (truncated to 512 tokens).

    Returns:
        Dict with:
            triggered_labels: {label: rounded score} for labels whose raw
                probability met or exceeded that label's threshold,
            all_scores: {label: {"score", "threshold", "triggered"}} detail
                for every label,
            latency_ms: wall-clock inference time in milliseconds,
            num_triggered: number of triggered labels.
    """
    start = time.time()
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding=True
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # Multi-label head: independent sigmoid per label (not softmax).
    probs = torch.sigmoid(outputs.logits).squeeze().tolist()
    latency_ms = round((time.time() - start) * 1000, 1)

    triggered = {}
    all_scores = {}
    for i, label in enumerate(LABEL_NAMES):
        raw = probs[i]
        score = round(raw, 4)  # rounded for display/serialization only
        threshold = THRESHOLDS[label]
        # FIX: compare the RAW probability to the threshold. The previous code
        # compared the rounded score, which could flip marginal cases — e.g.
        # raw 0.04996 rounds to 0.05 and would wrongly trigger a 0.05 threshold.
        hit = raw >= threshold
        all_scores[label] = {
            "score": score,
            "threshold": threshold,
            "triggered": hit,
        }
        if hit:
            triggered[label] = score
    return {
        "triggered_labels": triggered,
        "all_scores": all_scores,
        "latency_ms": latency_ms,
        "num_triggered": len(triggered),
    }


def classify_text(text):
    """Gradio UI wrapper: render predict() output as human-readable text."""
    if not text or not text.strip():
        return "Please enter text to classify."
    result = predict(text)
    output_lines = [
        f"Latency: {result['latency_ms']}ms",
        f"Labels triggered: {result['num_triggered']}",
        "",
    ]
    if result['triggered_labels']:
        output_lines.append("TRIGGERED LABELS:")
        # Highest-confidence labels first.
        for label, score in sorted(
            result['triggered_labels'].items(), key=lambda x: -x[1]
        ):
            output_lines.append(
                f"  {label}: {score:.4f} (threshold: {THRESHOLDS[label]})"
            )
    else:
        output_lines.append("NO LABELS TRIGGERED (text appears compliant)")
    return "\n".join(output_lines)


# Gradio UI
with gr.Blocks(title="PRISM DeBERTa Classifier") as demo:
    gr.Markdown("# PRISM DeBERTa v2 — Compliance Classifier")
    gr.Markdown(
        "25-label multi-label classifier for supplement/hemp/peptide compliance scanning."
    )
    with gr.Row():
        text_input = gr.Textbox(
            label="Text to classify",
            lines=5,
            placeholder="Enter product page text...",
        )
        output = gr.Textbox(label="Results", lines=12)
    btn = gr.Button("Classify", variant="primary")
    btn.click(fn=classify_text, inputs=text_input, outputs=output)


# Mount custom FastAPI REST endpoint for scanner integration
app = FastAPI()


class PredictRequest(BaseModel):
    # Raw text payload to classify.
    text: str
@app.post("/api/classify")
def api_classify(req: PredictRequest):
    """REST API endpoint for PRISM scanner. POST {"text": "..."} -> predictions.

    FIX: declared as a sync `def` (was `async def`). predict() runs a blocking
    PyTorch forward pass; awaiting nothing inside an async endpoint would stall
    the whole event loop for the duration of every inference. FastAPI executes
    sync endpoints in its threadpool, keeping the server responsive. The HTTP
    interface (path, method, request/response bodies) is unchanged.
    """
    return predict(req.text)


@app.get("/api/health")
def health():
    """Liveness probe: reports model id and label count (sync — trivially fast)."""
    return {"status": "ok", "model": MODEL_ID, "labels": len(LABEL_NAMES)}


# Mount Gradio app on FastAPI at the root path; REST routes above take priority.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)