Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from typing import List | |
| import torch | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| XLMRobertaForSequenceClassification, | |
| ) | |
# One ASGI application object; every route in this module is attached to it.
app = FastAPI(
    title="Unified NLP API",
)
| # ===================== | |
| # Agreement (MNLI) | |
| # ===================== | |
# Hugging Face Hub id of the NLI checkpoint used by check_agreement().
# NOTE(review): confirm this id resolves on the Hub; the commonly used MNLI
# checkpoint published by facebook is "facebook/bart-large-mnli".
MNLI_MODEL = "facebook/bart-base-mnli"
# Both stay None until load_mnli() populates them on first use.
mnli_tokenizer = mnli_model = None
def load_mnli():
    """Load the MNLI tokenizer and model into module globals on first use.

    Later calls return immediately because ``mnli_model`` is already set.
    No locking is done, so concurrent first calls may each perform the load.
    """
    global mnli_tokenizer, mnli_model
    if mnli_model is not None:
        return
    mnli_tokenizer = AutoTokenizer.from_pretrained(MNLI_MODEL)
    mnli_model = AutoModelForSequenceClassification.from_pretrained(MNLI_MODEL)
    # Inference only: put the network in evaluation mode.
    mnli_model.eval()
def check_agreement(msg1: str, msg2: str) -> float:
    """Score how far ``msg2`` agrees with ``msg1`` using the MNLI model.

    The pair is encoded with ``msg1`` first and ``msg2`` second. The score is
    P(entailment) - P(contradiction) from the softmax over the model's
    classes, rounded to two decimals, so it lies in [-1.0, 1.0].

    Args:
        msg1: First message of the pair (the NLI premise).
        msg2: Second message of the pair (the NLI hypothesis).

    Returns:
        The agreement score as a Python float.
    """
    load_mnli()
    inputs = mnli_tokenizer(msg1, msg2, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = mnli_model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]
    # Read the class indices from the model config instead of hard-coding
    # them: MNLI checkpoints do not all use the same label order. When the
    # config only carries generic "LABEL_n" names, fall back to the common
    # layout (contradiction=0, entailment=2) that this code originally assumed.
    label2id = {
        str(name).lower(): int(idx)
        for name, idx in mnli_model.config.label2id.items()
    }
    entail_idx = label2id.get("entailment", 2)
    contra_idx = label2id.get("contradiction", 0)
    return round((probs[entail_idx] - probs[contra_idx]).item(), 2)
| # ===================== | |
| # Sentiment | |
| # ===================== | |
# Hugging Face Hub id of the rating model used by analyze_sentiment(), which
# reads the predicted class index + 1 as a star rating.
SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
# Both stay None until load_sentiment() populates them on first use.
sent_tokenizer = sent_model = None
def load_sentiment():
    """Load the sentiment tokenizer and model into module globals on first use.

    Later calls return immediately because ``sent_model`` is already set.
    No locking is done, so concurrent first calls may each perform the load.
    """
    global sent_tokenizer, sent_model
    if sent_model is not None:
        return
    sent_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL)
    sent_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL)
    # Inference only: put the network in evaluation mode.
    sent_model.eval()
def analyze_sentiment(text: str) -> float:
    """Turn the model's most likely star rating for ``text`` into a score.

    The highest-probability class index is read as a rating of index + 1
    stars. The rating is centred on 3 stars and scaled by 2.5. For a
    five-class model the result is one of -5.0, -2.5, 0.0, 2.5 or 5.0.
    """
    load_sentiment()
    encoded = sent_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        scores = sent_model(**encoded).logits
    best_class = torch.softmax(scores, dim=-1).argmax(dim=-1).item()
    star_rating = best_class + 1
    return round(2.5 * (star_rating - 3), 2)
| # ===================== | |
| # Multilabel classifier | |
| # ===================== | |
# Hugging Face Hub id of the encoder used by load_classifier().
# NOTE(review): "xlm-roberta-base" is a base checkpoint without a trained
# classification head, so the CATEGORIES-sized head is randomly initialised
# when loaded; use a checkpoint fine-tuned on CATEGORIES before trusting output.
CLASSIFIER_MODEL = "xlm-roberta-base"
# Output labels, in the same order as the classifier's logits.
CATEGORIES = [
    "politique",
    "woke",
    "racism",
    "crime",
    "police_abuse",
    "corruption",
    "hate_speech",
    "activism",
]
# Both stay None until load_classifier() populates them on first use.
clf_tokenizer = clf_model = None
def load_classifier():
    """Load the multi-label classifier and its tokenizer on first use.

    The model is created with one output per entry in CATEGORIES. Later
    calls return immediately because ``clf_model`` is already set.

    NOTE(review): if CLASSIFIER_MODEL is a base checkpoint (such as
    "xlm-roberta-base"), transformers randomly initialises the classification
    head here, so predictions carry no meaning until a fine-tuned checkpoint
    is used.
    """
    global clf_tokenizer, clf_model
    if clf_model is not None:
        return
    clf_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL)
    label_count = len(CATEGORIES)
    clf_model = XLMRobertaForSequenceClassification.from_pretrained(
        CLASSIFIER_MODEL, num_labels=label_count
    )
    # Inference only: put the network in evaluation mode.
    clf_model.eval()
def classify_message(text: str, threshold: float = 0.5) -> List[str]:
    """Return every category whose probability for ``text`` exceeds ``threshold``.

    Each label is scored independently with a sigmoid (multi-label), so any
    number of categories can be returned. Predictions are only meaningful if
    CLASSIFIER_MODEL is a checkpoint fine-tuned on CATEGORIES.

    Args:
        text: Message to classify.
        threshold: Strict lower bound a label's probability must exceed to be
            reported. Defaults to 0.5, the previously hard-coded cutoff.

    Returns:
        Matching category names in CATEGORIES order, or ``["neutral"]`` when
        no category clears the threshold.
    """
    load_classifier()
    inputs = clf_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = clf_model(**inputs).logits
    # Per-label sigmoid rather than a softmax across labels: the categories
    # are not mutually exclusive.
    probs = torch.sigmoid(logits)[0].tolist()
    labels = [name for name, p in zip(CATEGORIES, probs) if p > threshold]
    return labels or ["neutral"]
| # ===================== | |
| # API schemas | |
| # ===================== | |
# Request body for the agreement endpoint: two messages compared by the
# MNLI model. Kept as "#" comments because a class docstring would change
# the generated pydantic/OpenAPI schema description.
class AgreementRequest(BaseModel):
    # First message of the pair (passed first to the NLI tokenizer).
    msg1: str
    # Second message, scored for agreement with msg1.
    msg2: str
# Request body shared by the single-text endpoints (sentiment and
# classification). Kept as "#" comments because a class docstring would
# change the generated pydantic/OpenAPI schema description.
class TextRequest(BaseModel):
    # Raw input text to analyse.
    text: str
| # ===================== | |
| # Endpoints | |
| # ===================== | |
@app.post("/agreement")
def agreement(req: AgreementRequest):
    """Score how far ``req.msg2`` agrees with ``req.msg1``.

    Registered as ``POST /agreement`` (the path advertised by the root
    endpoint). Returns ``{"agreement_score": <float in [-1, 1]>}``.
    """
    return {"agreement_score": check_agreement(req.msg1, req.msg2)}
@app.post("/sentiment")
def sentiment(req: TextRequest):
    """Return the sentiment score of ``req.text`` on a -5..+5 scale.

    Registered as ``POST /sentiment`` (the path advertised by the root
    endpoint). Returns ``{"sentiment_score": <float>}``.
    """
    return {"sentiment_score": analyze_sentiment(req.text)}
@app.post("/classify")
def classify(req: TextRequest):
    """Return the multi-label categories predicted for ``req.text``.

    Registered as ``POST /classify`` (the path advertised by the root
    endpoint). Returns ``{"categories": [<name>, ...]}``, which is
    ``["neutral"]`` when no category is predicted.
    """
    return {"categories": classify_message(req.text)}
@app.get("/")
def root():
    """Health check that also lists the service's endpoints.

    Registered as ``GET /``; without a route decorator this index was never
    reachable.
    """
    return {
        "status": "ok",
        "endpoints": {
            "POST /sentiment": "sentiment analysis",
            "POST /agreement": "text agreement",
            "POST /classify": "multilabel classification",
            "GET /docs": "swagger UI",
        },
    }