Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| # === Загружаем модель один раз при старте сервиса === | |
| MODEL_NAME = "facebook/bart-large-mnli" | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | |
| model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME) | |
| model.eval() | |
| # === Создаем FastAPI приложение === | |
| app = FastAPI(title="Agreement Checker API") | |
| # === Модель запроса === | |
| class MessagePair(BaseModel): | |
| msg1: str | |
| msg2: str | |
| # === Основная логика проверки согласия === | |
| def check_agreement(msg1: str, msg2: str) -> float: | |
| inputs = tokenizer(msg1, msg2, return_tensors="pt", truncation=True) | |
| with torch.no_grad(): | |
| logits = model(**inputs).logits | |
| probs = torch.softmax(logits, dim=-1) | |
| entailment_prob = probs[0][2].item() # entailment | |
| contradiction_prob = probs[0][0].item() # contradiction | |
| score = entailment_prob - contradiction_prob | |
| return round(score, 2) | |
| # === Эндпоинт API === | |
| def agreement(pair: MessagePair): | |
| score = check_agreement(pair.msg1, pair.msg2) | |
| return {"agreement_score": score} | |