| import gradio as gr |
| import torch |
| import json |
| import time |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| from fastapi import FastAPI |
| from pydantic import BaseModel |
|
|
| |
# Output labels of the classifier head, in the model's logit order —
# index i of the sigmoid output corresponds to LABEL_NAMES[i] in predict().
# L01–L28 are violation labels; L31–L33 mark compliant content.
# NOTE(review): numbering has gaps (no L15–L18, L20, L26, L29–L30) —
# presumably labels retired in earlier training iterations; confirm against
# the model card before renumbering anything.
LABEL_NAMES = [
    "L01_DISEASE_CLAIM", "L02_STRUCTURE_FUNCTION", "L03_UNSUBSTANTIATED_EFFICACY",
    "L04_DRUG_COMPARISON", "L05_DOSING_INSTRUCTIONS", "L06_IMPLIED_HUMAN_USE",
    "L07_PHARMA_GRADE_CLAIM", "L08_TESTIMONIAL_VIOLATION", "L09_MISSING_DISCLAIMER",
    "L10_MISLEADING_DISCLAIMER", "L11_MISSING_COA", "L12_FDA_FALSE_CLAIM",
    "L13_MISSING_DISCLOSURES", "L14_FAIR_BALANCE_VIOLATION", "L19_DECEPTIVE_PRICING",
    "L21_CONSUMER_LANGUAGE", "L22_BODYBUILDING_CONTENT", "L23_META_HEALTH_KEYWORDS",
    "L24_COVER_CONTRADICTION", "L25_MISSING_RESEARCH_ID", "L27_THCA_LEGALITY_CLAIM",
    "L28_CBD_DISEASE_CLAIM", "L31_COMPLIANT_RESEARCH", "L32_COMPLIANT_DISCLAIMER",
    "L33_COMPLIANT_MARKETING"
]
|
|
| |
# Per-label decision thresholds applied to the sigmoid scores in predict():
# a label is "triggered" when score >= THRESHOLDS[label].
# Values are per-label tuned (presumably on a held-out validation set —
# not shown in this file); keys must stay in sync with LABEL_NAMES.
THRESHOLDS = {
    "L01_DISEASE_CLAIM": 0.35, "L02_STRUCTURE_FUNCTION": 0.55,
    "L03_UNSUBSTANTIATED_EFFICACY": 0.15, "L04_DRUG_COMPARISON": 0.125,
    "L05_DOSING_INSTRUCTIONS": 0.05, "L06_IMPLIED_HUMAN_USE": 0.275,
    "L07_PHARMA_GRADE_CLAIM": 0.375, "L08_TESTIMONIAL_VIOLATION": 0.25,
    "L09_MISSING_DISCLAIMER": 0.15, "L10_MISLEADING_DISCLAIMER": 0.5,
    "L11_MISSING_COA": 0.65, "L12_FDA_FALSE_CLAIM": 0.05,
    "L13_MISSING_DISCLOSURES": 0.075, "L14_FAIR_BALANCE_VIOLATION": 0.325,
    "L19_DECEPTIVE_PRICING": 0.05, "L21_CONSUMER_LANGUAGE": 0.075,
    "L22_BODYBUILDING_CONTENT": 0.075, "L23_META_HEALTH_KEYWORDS": 0.05,
    "L24_COVER_CONTRADICTION": 0.15, "L25_MISSING_RESEARCH_ID": 0.425,
    "L27_THCA_LEGALITY_CLAIM": 0.15, "L28_CBD_DISEASE_CLAIM": 0.175,
    "L31_COMPLIANT_RESEARCH": 0.4, "L32_COMPLIANT_DISCLAIMER": 0.3,
    "L33_COMPLIANT_MARKETING": 0.05
}
|
|
# Hugging Face Hub id of the fine-tuned multi-label classifier.
MODEL_ID = "Austin-Groundsetter/deberta-prism-v2"


# Load the tokenizer and model once at import time so every request
# (Gradio UI and REST API alike) reuses the same in-memory instances.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # inference only: disables dropout / puts norm layers in eval mode
print("Model loaded!")
|
|
def predict(text: str) -> dict:
    """Classify *text* against all compliance labels.

    Runs one forward pass through the multi-label model, applies a sigmoid
    to the logits, and compares each label's score to its per-label
    threshold from THRESHOLDS.

    Args:
        text: Raw text to classify; tokenized with truncation at 512 tokens.

    Returns:
        dict with:
            triggered_labels: {label: score} for labels where score >= threshold.
            all_scores: {label: {"score", "threshold", "triggered"}} for every label.
            latency_ms: tokenize + inference wall time in milliseconds.
            num_triggered: number of triggered labels.
    """
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # so the reported latency can never be negative or skewed by clock steps.
    start = time.perf_counter()
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # squeeze(0) drops only the batch dimension; a bare squeeze() would also
    # collapse the label dimension if the head ever had a single label,
    # turning the expected list into a scalar.
    probs = torch.sigmoid(outputs.logits).squeeze(0).tolist()
    latency_ms = round((time.perf_counter() - start) * 1000, 1)

    triggered = {}
    all_scores = {}
    for i, label in enumerate(LABEL_NAMES):
        score = round(probs[i], 4)
        threshold = THRESHOLDS[label]
        is_triggered = score >= threshold  # single comparison reused below
        all_scores[label] = {"score": score, "threshold": threshold, "triggered": is_triggered}
        if is_triggered:
            triggered[label] = score

    return {
        "triggered_labels": triggered,
        "all_scores": all_scores,
        "latency_ms": latency_ms,
        "num_triggered": len(triggered),
    }
|
|
def classify_text(text):
    """Gradio UI wrapper: render predict() output as a plain-text report."""
    # Guard: blank or missing input never reaches the model.
    if not text or not text.strip():
        return "Please enter text to classify."

    result = predict(text)
    report = [
        f"Latency: {result['latency_ms']}ms",
        f"Labels triggered: {result['num_triggered']}",
        "",
    ]
    hits = result['triggered_labels']
    if not hits:
        report.append("NO LABELS TRIGGERED (text appears compliant)")
    else:
        report.append("TRIGGERED LABELS:")
        # Highest-confidence violations first.
        ranked = sorted(hits.items(), key=lambda pair: pair[1], reverse=True)
        for label, score in ranked:
            report.append(f"  {label}: {score:.4f} (threshold: {THRESHOLDS[label]})")
    return "\n".join(report)
|
|
| |
# --- Gradio demo UI (mounted at "/" further down via mount_gradio_app) ---
with gr.Blocks(title="PRISM DeBERTa Classifier") as demo:
    gr.Markdown("# PRISM DeBERTa v2 — Compliance Classifier")
    gr.Markdown("25-label multi-label classifier for supplement/hemp/peptide compliance scanning.")
    with gr.Row():
        # Input and results textboxes rendered side by side.
        text_input = gr.Textbox(label="Text to classify", lines=5, placeholder="Enter product page text...")
        output = gr.Textbox(label="Results", lines=12)
    btn = gr.Button("Classify", variant="primary")
    # Button wires the UI to the formatting wrapper, not to predict() directly.
    btn.click(fn=classify_text, inputs=text_input, outputs=output)
|
|
| |
# Single FastAPI app hosting both the REST endpoints and the Gradio UI.
app = FastAPI()


class PredictRequest(BaseModel):
    """Request body schema for POST /api/classify."""
    # Raw text to classify; pydantic rejects non-string / missing values.
    text: str
|
|
@app.post("/api/classify")
async def api_classify(req: PredictRequest) -> dict:
    """REST API endpoint for PRISM scanner. POST {"text": "..."} -> predictions."""
    # Validation is handled by the PredictRequest model; this just delegates.
    return predict(req.text)
|
|
@app.get("/api/health")
async def health():
    """Liveness probe: reports status, the model id, and the label count."""
    payload = {
        "status": "ok",
        "model": MODEL_ID,
        "labels": len(LABEL_NAMES),
    }
    return payload
|
|
| |
| app = gr.mount_gradio_app(app, demo, path="/") |
|
|
if __name__ == "__main__":
    # Local import: uvicorn is only needed when running this file directly.
    import uvicorn
    # 0.0.0.0:7860 — the conventional host/port for Hugging Face Spaces.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|