# Demo / app.py
import os

# HF_HOME must be set before transformers/huggingface_hub is imported,
# otherwise the default cache path has already been resolved to a
# location that may not be writable on Spaces-style hosts.
os.environ["HF_HOME"] = "/tmp"

from datetime import datetime

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline

SPAM_MODEL = "valurank/distilroberta-spam-comments-detection"
TOXIC_MODEL = "s-nlp/roberta_toxicity_classifier"
SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
NSFW_MODEL = "michellejieli/NSFW_text_classifier"
# Load models
spam = pipeline("text-classification", model=SPAM_MODEL)
toxic = pipeline("text-classification", model=TOXIC_MODEL)
sentiment = pipeline("text-classification", model=SENTIMENT_MODEL)
nsfw = pipeline("text-classification", model=NSFW_MODEL)
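
# For reference, each text-classification pipeline call returns a list of
# {"label": ..., "score": ...} dicts; the endpoints below index [0] and read those
# two keys. The label name and score here are illustrative only and depend on each
# model's config, e.g.:
#
#   toxic("some text")  ->  [{"label": "neutral", "score": 0.99}]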
app = FastAPI(title="Plebzs AI Models API")
class Query(BaseModel):
    text: str

@app.get("/")
def root():
return {"status": "ok", "message": "Plebzs AI Models API"}
# Required by Plebzs boss
@app.get("/moderation/ping")
def moderation_ping():
    return {
        "status": "healthy",
        "models": ["spam", "toxic", "sentiment", "nsfw"],
        "timestamp": datetime.now().isoformat(),
        "version": "1.0.0",
    }

# Main endpoints - formatted for Plebzs
@app.post("/toxicity") # Changed name to match Plebzs expectation
def predict_toxicity(query: Query):
    result = toxic(query.text)[0]
    # Convert to a 0-1 toxicity scale; label casing varies between model configs,
    # so compare case-insensitively
    toxicity_score = result["score"] if result["label"].lower() == "toxic" else 1 - result["score"]
    return {
        "toxicity_score": round(toxicity_score, 3),
        "confidence": round(result["score"], 3),
        "raw_output": result,
    }

@app.post("/sentiment")
def predict_sentiment(query: Query):
    result = sentiment(query.text)[0]
    # Convert star rating to -1 to 1 scale
    label = result["label"]
    if "1" in label or "2" in label:  # 1-2 stars = negative
        sentiment_score = -0.7
    elif "3" in label:  # 3 stars = neutral
        sentiment_score = 0.0
    else:  # 4-5 stars = positive
        sentiment_score = 0.7
    return {
        "sentiment_score": round(sentiment_score, 3),
        "confidence": round(result["score"], 3),
        "raw_output": result,
    }

# Bonus endpoints (not used by Plebzs yet, but good to have)
@app.post("/spam")
def predict_spam(query: Query):
    result = spam(query.text)[0]
    # Label casing varies between model configs, so compare case-insensitively
    spam_score = result["score"] if result["label"].lower() == "spam" else 1 - result["score"]
    return {
        "spam_score": round(spam_score, 3),
        "confidence": round(result["score"], 3),
        "raw_output": result,
    }

@app.post("/nsfw")
def predict_nsfw(query: Query):
    result = nsfw(query.text)[0]
    # Label casing varies between model configs, so compare case-insensitively
    nsfw_score = result["score"] if result["label"].lower() == "nsfw" else 1 - result["score"]
    return {
        "nsfw_score": round(nsfw_score, 3),
        "confidence": round(result["score"], 3),
        "raw_output": result,
    }

# Keep your detailed health check
@app.get("/health")
def health_check():
    status = {
        "server": "running",
        "models": {},
    }
    models = {
        "spam": (SPAM_MODEL, spam),
        "toxic": (TOXIC_MODEL, toxic),
        "sentiment": (SENTIMENT_MODEL, sentiment),
        "nsfw": (NSFW_MODEL, nsfw),
    }
    for key, (model_name, model_pipeline) in models.items():
        try:
            # Run a tiny inference to confirm each pipeline is usable
            model_pipeline("test")
            status["models"][key] = {
                "model_name": model_name,
                "status": "running",
            }
        except Exception as e:
            status["models"][key] = {
                "model_name": model_name,
                "status": f"error: {str(e)}",
            }
    return status
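
# Minimal local-run sketch (an assumption, not part of the original file): HF Spaces
# normally launches the server itself, and this assumes uvicorn is installed; the
# port 7860 below is a guess, adjust as needed.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# A request against a running instance might then look like (hypothetical host/port):
#   curl -X POST http://localhost:7860/toxicity \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hello there"}'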