Austin-Groundsetter's picture
Upload app.py with huggingface_hub
97e01b5 verified
import gradio as gr
import torch
import json
import time
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from fastapi import FastAPI
from pydantic import BaseModel
# Label identifiers in model-output order: predict() pairs the entry at
# position i with the i-th per-label probability, so this order must not be
# changed independently of the model. The numeric prefixes are not contiguous
# (L15-L18, L20, L26, L29 and L30 do not appear).
LABEL_NAMES = [
    "L01_DISEASE_CLAIM",
    "L02_STRUCTURE_FUNCTION",
    "L03_UNSUBSTANTIATED_EFFICACY",
    "L04_DRUG_COMPARISON",
    "L05_DOSING_INSTRUCTIONS",
    "L06_IMPLIED_HUMAN_USE",
    "L07_PHARMA_GRADE_CLAIM",
    "L08_TESTIMONIAL_VIOLATION",
    "L09_MISSING_DISCLAIMER",
    "L10_MISLEADING_DISCLAIMER",
    "L11_MISSING_COA",
    "L12_FDA_FALSE_CLAIM",
    "L13_MISSING_DISCLOSURES",
    "L14_FAIR_BALANCE_VIOLATION",
    "L19_DECEPTIVE_PRICING",
    "L21_CONSUMER_LANGUAGE",
    "L22_BODYBUILDING_CONTENT",
    "L23_META_HEALTH_KEYWORDS",
    "L24_COVER_CONTRADICTION",
    "L25_MISSING_RESEARCH_ID",
    "L27_THCA_LEGALITY_CLAIM",
    "L28_CBD_DISEASE_CLAIM",
    "L31_COMPLIANT_RESEARCH",
    "L32_COMPLIANT_DISCLAIMER",
    "L33_COMPLIANT_MARKETING",
]
# Per-label decision thresholds (noted by the author as optimized during
# training). predict() reports a label as triggered when its rounded score is
# greater than or equal to the value stored here, so every label that
# predict() iterates over needs an entry.
THRESHOLDS = {
    "L01_DISEASE_CLAIM": 0.35,
    "L02_STRUCTURE_FUNCTION": 0.55,
    "L03_UNSUBSTANTIATED_EFFICACY": 0.15,
    "L04_DRUG_COMPARISON": 0.125,
    "L05_DOSING_INSTRUCTIONS": 0.05,
    "L06_IMPLIED_HUMAN_USE": 0.275,
    "L07_PHARMA_GRADE_CLAIM": 0.375,
    "L08_TESTIMONIAL_VIOLATION": 0.25,
    "L09_MISSING_DISCLAIMER": 0.15,
    "L10_MISLEADING_DISCLAIMER": 0.5,
    "L11_MISSING_COA": 0.65,
    "L12_FDA_FALSE_CLAIM": 0.05,
    "L13_MISSING_DISCLOSURES": 0.075,
    "L14_FAIR_BALANCE_VIOLATION": 0.325,
    "L19_DECEPTIVE_PRICING": 0.05,
    "L21_CONSUMER_LANGUAGE": 0.075,
    "L22_BODYBUILDING_CONTENT": 0.075,
    "L23_META_HEALTH_KEYWORDS": 0.05,
    "L24_COVER_CONTRADICTION": 0.15,
    "L25_MISSING_RESEARCH_ID": 0.425,
    "L27_THCA_LEGALITY_CLAIM": 0.15,
    "L28_CBD_DISEASE_CLAIM": 0.175,
    "L31_COMPLIANT_RESEARCH": 0.4,
    "L32_COMPLIANT_DISCLAIMER": 0.3,
    "L33_COMPLIANT_MARKETING": 0.05,
}
# Identifier handed to from_pretrained() for both the tokenizer and the
# classifier weights; also reported by the /api/health endpoint.
MODEL_ID = "Austin-Groundsetter/deberta-prism-v2"

# Load once at import time so every request reuses the same tokenizer/model.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# eval() puts the module in inference mode and returns the module itself,
# which is why it can be chained directly onto the load call.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).eval()
print("Model loaded!")
def predict(text: str) -> dict:
    """Classify text and return label predictions with confidence scores.

    The text is tokenized (truncated to at most 512 tokens), run through the
    module-level ``model`` without gradient tracking, and each output logit is
    passed through a sigmoid to obtain an independent per-label probability.

    Args:
        text: The text to classify.

    Returns:
        A dict with four keys:

        * ``"triggered_labels"``: ``{label: score}`` for every label whose
          score is at or above its entry in ``THRESHOLDS``.
        * ``"all_scores"``: ``{label: {"score", "threshold", "triggered"}}``
          for every label in ``LABEL_NAMES``.
        * ``"latency_ms"``: tokenization + inference time in milliseconds,
          rounded to one decimal place.
        * ``"num_triggered"``: number of entries in ``"triggered_labels"``.

    Raises:
        IndexError: If the model yields fewer outputs than ``LABEL_NAMES`` has
            entries.
        KeyError: If a label has no entry in ``THRESHOLDS``.
    """
    # perf_counter() is monotonic; time.time() follows the wall clock and can
    # jump (e.g. on clock adjustment), which would corrupt the latency figure.
    start = time.perf_counter()
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # squeeze(0) removes only the leading batch axis (one input text), so the
    # result stays a list of per-label values regardless of the label count.
    probs = torch.sigmoid(outputs.logits).squeeze(0).tolist()
    latency_ms = round((time.perf_counter() - start) * 1000, 1)

    triggered = {}
    all_scores = {}
    for i, label in enumerate(LABEL_NAMES):
        # The reported (4-decimal) score is also the value that is compared,
        # so the "score" and "triggered" fields can never disagree.
        score = round(probs[i], 4)
        threshold = THRESHOLDS[label]
        is_triggered = score >= threshold
        all_scores[label] = {"score": score, "threshold": threshold, "triggered": is_triggered}
        if is_triggered:
            triggered[label] = score

    return {
        "triggered_labels": triggered,
        "all_scores": all_scores,
        "latency_ms": latency_ms,
        "num_triggered": len(triggered),
    }
def classify_text(text):
    """Format predict() results as plain text for the Gradio results box.

    Missing or whitespace-only input returns a prompt message without running
    the model. Otherwise the report starts with the latency and the number of
    triggered labels, followed by the triggered labels ordered from highest
    to lowest score, or a single line stating that nothing was triggered.
    """
    if not (text and text.strip()):
        return "Please enter text to classify."

    report = predict(text)
    hits = report['triggered_labels']

    lines = [
        f"Latency: {report['latency_ms']}ms",
        f"Labels triggered: {report['num_triggered']}",
        "",
    ]
    if not hits:
        lines.append("NO LABELS TRIGGERED (text appears compliant)")
    else:
        lines.append("TRIGGERED LABELS:")
        # Highest score first; the sort is stable, so ties keep label order.
        ranked = sorted(hits.items(), key=lambda pair: pair[1], reverse=True)
        lines.extend(
            f" {name}: {value:.4f} (threshold: {THRESHOLDS[name]})"
            for name, value in ranked
        )
    return "\n".join(lines)
# Gradio UI: a single-page form that passes the entered text to
# classify_text() and displays the formatted report it returns.
with gr.Blocks(title="PRISM DeBERTa Classifier") as demo:
    gr.Markdown("# PRISM DeBERTa v2 — Compliance Classifier")
    gr.Markdown("25-label multi-label classifier for supplement/hemp/peptide compliance scanning.")
    # NOTE(review): the nesting below was reconstructed (both textboxes in the
    # row, button underneath) — confirm against the deployed layout.
    with gr.Row():
        text_input = gr.Textbox(label="Text to classify", lines=5, placeholder="Enter product page text...")
        output = gr.Textbox(label="Results", lines=12)
    btn = gr.Button("Classify", variant="primary")
    # On click, run classify_text on the input box and write the returned
    # string into the results box.
    btn.click(fn=classify_text, inputs=text_input, outputs=output)
# FastAPI application that carries the custom REST endpoints used for scanner
# integration; the /api/* routes are registered on it below.
app = FastAPI()
# Request body schema for POST /api/classify: a JSON object with one string
# field. (Documented with comments rather than a docstring so the generated
# API schema is left exactly as it was.)
class PredictRequest(BaseModel):
    # Raw text to classify; passed unchanged to predict().
    text: str
@app.post("/api/classify")
async def api_classify(req: PredictRequest):
    """REST API endpoint for PRISM scanner. POST {"text": "..."} -> predictions."""
    # predict() does blocking tokenization and model inference. Calling it
    # directly inside this coroutine would stall the event loop, and with it
    # every other request (including /api/health), for the whole inference.
    # Running it on a worker thread keeps the loop responsive.
    from fastapi.concurrency import run_in_threadpool

    # The response body is the dict returned by predict(), unchanged.
    return await run_in_threadpool(predict, req.text)
@app.get("/api/health")
async def health():
    # Static status payload: a fixed "ok" marker, the configured model
    # identifier, and how many labels the classifier reports. No inference is
    # run here. (Comments instead of a docstring so the generated API
    # description is unchanged.)
    return dict(status="ok", model=MODEL_ID, labels=len(LABEL_NAMES))
# Mount the Gradio UI on the FastAPI app at the site root. The /api/* routes
# above were registered on `app` before this call; `app` is rebound to the
# value returned by mount_gradio_app.
app = gr.mount_gradio_app(app, demo, path="/")
# Script entry point: serve the combined FastAPI + Gradio app with uvicorn.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run directly.
    import uvicorn

    # Listen on all network interfaces on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)