classifier / agreement_score.py
narutoSiskovich's picture
Create agreement_score.py
6bb3fde verified
raw
history blame
1.29 kB
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# === Загружаем модель один раз при старте сервиса ===
MODEL_NAME = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()
# === Создаем FastAPI приложение ===
app = FastAPI(title="Agreement Checker API")
# === Модель запроса ===
class MessagePair(BaseModel):
msg1: str
msg2: str
# === Основная логика проверки согласия ===
def check_agreement(msg1: str, msg2: str) -> float:
inputs = tokenizer(msg1, msg2, return_tensors="pt", truncation=True)
with torch.no_grad():
logits = model(**inputs).logits
probs = torch.softmax(logits, dim=-1)
entailment_prob = probs[0][2].item() # entailment
contradiction_prob = probs[0][0].item() # contradiction
score = entailment_prob - contradiction_prob
return round(score, 2)
# === Эндпоинт API ===
@app.post("/agreement")
def agreement(pair: MessagePair):
score = check_agreement(pair.msg1, pair.msg2)
return {"agreement_score": score}