# classifier / app.py
# NOTE(review): the following lines were Hugging Face Hub page artifacts
# accidentally captured with the source (uploader "narutoSiskovich",
# commit 99aa3de "Update app.py", file size 3.67 kB); kept as comments
# so the module remains valid Python.
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
XLMRobertaForSequenceClassification,
)
app = FastAPI(title="Unified NLP API")
# =====================
# Agreement (MNLI)
# =====================
# FIX(review): the original pointed at "facebook/bart-base-mnli", which is not
# a published checkpoint on the Hugging Face Hub and would raise at load time;
# "facebook/bart-large-mnli" is the official BART MNLI model
# (label order: 0=contradiction, 1=neutral, 2=entailment).
MNLI_MODEL = "facebook/bart-large-mnli"
# Lazily initialized by load_mnli() on the first /agreement request.
mnli_tokenizer = None
mnli_model = None
def load_mnli():
    """Lazily download and cache the MNLI tokenizer/model (idempotent).

    Populates the module-level ``mnli_tokenizer`` / ``mnli_model`` globals
    on the first call; subsequent calls return immediately.
    """
    global mnli_tokenizer, mnli_model
    if mnli_model is not None:
        return  # already loaded
    mnli_tokenizer = AutoTokenizer.from_pretrained(MNLI_MODEL)
    mnli_model = AutoModelForSequenceClassification.from_pretrained(MNLI_MODEL)
    mnli_model.eval()
def check_agreement(msg1: str, msg2: str) -> float:
    """Score how much *msg2* agrees with *msg1*, in roughly [-1, 1].

    Runs the MNLI model on the (premise, hypothesis) pair and returns
    P(entailment) - P(contradiction), rounded to two decimals.
    NOTE(review): assumes MNLI label order 0=contradiction, 1=neutral,
    2=entailment — true for BART MNLI checkpoints; confirm if the model
    constant changes.
    """
    load_mnli()
    encoded = mnli_tokenizer(msg1, msg2, return_tensors="pt", truncation=True)
    with torch.no_grad():
        scores = mnli_model(**encoded).logits
    dist = torch.softmax(scores, dim=-1)[0]
    entailment, contradiction = dist[2], dist[0]
    return round(float(entailment - contradiction), 2)
# =====================
# Sentiment
# =====================
# Multilingual 1-5 star review-sentiment model; analyze_sentiment() remaps
# the predicted star count onto a -5 .. +5 score.
SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
# Lazily initialized by load_sentiment() on the first /sentiment request.
sent_tokenizer = None
sent_model = None
def load_sentiment():
    """Lazily download and cache the sentiment tokenizer/model (idempotent).

    Fills the module-level ``sent_tokenizer`` / ``sent_model`` globals the
    first time it runs; later calls are no-ops.
    """
    global sent_tokenizer, sent_model
    if sent_model is not None:
        return  # already loaded
    sent_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL)
    sent_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL)
    sent_model.eval()
def analyze_sentiment(text: str) -> float:
    """Score the sentiment of *text* on a -5 .. +5 scale.

    The model predicts a 1-5 star rating; the most probable star count is
    shifted and scaled so 3 stars maps to 0.0 and the extremes map to
    -5.0 / +5.0, rounded to two decimals.
    """
    load_sentiment()
    encoded = sent_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = sent_model(**encoded)
    distribution = torch.softmax(outputs.logits, dim=-1)
    star_rating = torch.argmax(distribution, dim=-1).item() + 1
    # Map 1..5 stars onto -5..+5 (3 stars == neutral == 0.0).
    return round((star_rating - 3) * 2.5, 2)
# =====================
# Multilabel classifier
# =====================
# NOTE(review): "xlm-roberta-base" is a base pre-trained checkpoint, not a
# model fine-tuned on these categories; loading it with a fresh num_labels
# head (see load_classifier) leaves the classification weights randomly
# initialized, so /classify scores are presumably not meaningful until the
# head is fine-tuned — confirm whether a fine-tuned checkpoint was intended.
CLASSIFIER_MODEL = "xlm-roberta-base"
# Label names for the multi-label head; list order fixes which logit index
# corresponds to which category.
CATEGORIES = [
    "politique", "woke", "racism", "crime",
    "police_abuse", "corruption", "hate_speech", "activism"
]
# Lazily initialized by load_classifier() on the first /classify request.
clf_tokenizer = None
clf_model = None
def load_classifier():
    """Lazily build and cache the multilabel classifier (idempotent).

    Populates ``clf_tokenizer`` / ``clf_model`` on first use, sizing the
    classification head to len(CATEGORIES) labels.
    """
    global clf_tokenizer, clf_model
    if clf_model is not None:
        return  # already loaded
    clf_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL)
    clf_model = XLMRobertaForSequenceClassification.from_pretrained(
        CLASSIFIER_MODEL,
        num_labels=len(CATEGORIES),
    )
    clf_model.eval()
def classify_message(text: str) -> List[str]:
    """Return every category whose sigmoid score exceeds 0.5.

    Falls back to ``["neutral"]`` when no category clears the threshold.
    """
    load_classifier()
    encoded = clf_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        scores = torch.sigmoid(clf_model(**encoded).logits)[0]
    # Pair each label with its score; keep those above the 0.5 threshold.
    matched = [name for name, score in zip(CATEGORIES, scores) if score > 0.5]
    return matched if matched else ["neutral"]
# =====================
# API schemas
# =====================
class AgreementRequest(BaseModel):
    """Request body for POST /agreement: the two messages to compare."""
    msg1: str
    msg2: str
class TextRequest(BaseModel):
    """Request body for POST /sentiment and POST /classify: a single text."""
    text: str
# =====================
# Endpoints
# =====================
@app.post("/agreement")
def agreement(req: AgreementRequest):
    """Return the entailment-minus-contradiction agreement score for a pair."""
    score = check_agreement(req.msg1, req.msg2)
    return {"agreement_score": score}
@app.post("/sentiment")
def sentiment(req: TextRequest):
    """Return the -5..+5 sentiment score for the submitted text."""
    score = analyze_sentiment(req.text)
    return {"sentiment_score": score}
@app.post("/classify")
def classify(req: TextRequest):
    """Return the list of categories detected in the submitted text."""
    categories = classify_message(req.text)
    return {"categories": categories}
@app.get("/")
def root():
    """Health/discovery endpoint listing the routes this service exposes."""
    endpoints = {
        "POST /sentiment": "sentiment analysis",
        "POST /agreement": "text agreement",
        "POST /classify": "multilabel classification",
        "GET /docs": "swagger UI",
    }
    return {"status": "ok", "endpoints": endpoints}