# Source: api/main.py on the Hugging Face Hub (author chandkr123; commit
# b676669 "Upload api/main.py with huggingface_hub", verified; 1.75 kB).
import datetime
import json
import logging
import os
import time

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline
app = FastAPI(title="Toxic Comment Classifier", version="1.0")

# Hub repository the classifier weights are downloaded from at import time.
HF_MODEL_ID = "chandkr123/toxic-comment-classifier"

print(f"Loading model from {HF_MODEL_ID}...")
# transformers device convention: GPU ordinal 0 when CUDA is present, -1 = CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# top_k=None asks the pipeline for a score on every label, and a per-label
# sigmoid (instead of softmax) scores each label independently (multi-label).
pipe = pipeline(
    "text-classification",
    model=HF_MODEL_ID,
    top_k=None,
    function_to_apply="sigmoid",
    device=device,
)
print("✅ Model loaded")

# Presumably the model's label names -- confirm against the model config.
# NOTE(review): not referenced by the visible endpoints.
LABELS = [
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
]
# JSON Lines file that /predict appends one record per request to.
LOG_FILE = "predictions.jsonl"
class PredictRequest(BaseModel):
    # Request body accepted by POST /predict.
    text: str  # the comment to classify
    threshold: float = 0.5  # labels scoring strictly above this are flagged
class PredictResponse(BaseModel):
    # Response body returned by POST /predict.
    text: str  # echo of the submitted text
    predictions: dict  # label -> sigmoid score rounded to 4 decimals
    flagged: list  # labels whose score exceeded the request threshold
    latency_ms: float  # scoring time in milliseconds (excludes the log write)
@app.get("/")
def root():
    # Landing endpoint: identifies the service and points at the Swagger UI.
    return dict(message="Toxic Comment Classifier API", docs="/docs")
@app.get("/health")
def health():
    # Liveness probe; also reports which Hub model this process serves.
    return dict(status="ok", model=HF_MODEL_ID)
def _append_prediction_log(entry: dict) -> None:
    """Append *entry* to LOG_FILE as one JSON line, best-effort.

    A failed write (e.g. full or read-only disk) is logged and swallowed so it
    cannot turn an already computed prediction into a 500 response.
    """
    try:
        with open(LOG_FILE, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")
    except OSError:
        logging.getLogger(__name__).warning(
            "Could not append prediction to %s", LOG_FILE, exc_info=True
        )


@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    """Score `text` on every label and flag labels scoring above `threshold`.

    Returns the per-label scores, the flagged labels and the scoring latency
    in milliseconds.
    """
    # (Implementation notes are kept in comments: FastAPI publishes the
    # docstring above as the OpenAPI operation description.)
    t0 = time.perf_counter()
    # truncation=True clips the input to the tokenizer's maximum length; an
    # over-long comment otherwise typically makes the forward pass fail.
    output = pipe(req.text, truncation=True)
    # For a single string the pipeline wraps its per-label list in an outer
    # list (top_k is not passed at call time); unwrap defensively in case a
    # transformers version returns the flat list instead.
    result = output[0] if output and isinstance(output[0], list) else output
    scores = {r["label"]: round(r["score"], 4) for r in result}
    flagged = [label for label, score in scores.items() if score > req.threshold]
    latency = round((time.perf_counter() - t0) * 1000, 2)

    _append_prediction_log({
        # Timezone-aware UTC; datetime.utcnow() is naive and deprecated (3.12).
        "ts": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        "text": req.text[:200],  # cap how much user text is persisted
        "scores": scores,
        "flagged": flagged,
        "latency_ms": latency,
    })

    return PredictResponse(
        text=req.text,
        predictions=scores,
        flagged=flagged,
        latency_ms=latency,
    )