narutoSiskovich commited on
Commit
6bb3fde
·
verified ·
1 Parent(s): f0104db

Create agreement_score.py

Browse files
Files changed (1) hide show
  1. agreement_score.py +35 -0
agreement_score.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+
6
+ # === Загружаем модель один раз при старте сервиса ===
7
+ MODEL_NAME = "facebook/bart-large-mnli"
8
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
10
+ model.eval()
11
+
12
+ # === Создаем FastAPI приложение ===
13
+ app = FastAPI(title="Agreement Checker API")
14
+
15
+ # === Модель запроса ===
16
+ class MessagePair(BaseModel):
17
+ msg1: str
18
+ msg2: str
19
+
20
+ # === Основная логика проверки согласия ===
21
+ def check_agreement(msg1: str, msg2: str) -> float:
22
+ inputs = tokenizer(msg1, msg2, return_tensors="pt", truncation=True)
23
+ with torch.no_grad():
24
+ logits = model(**inputs).logits
25
+ probs = torch.softmax(logits, dim=-1)
26
+ entailment_prob = probs[0][2].item() # entailment
27
+ contradiction_prob = probs[0][0].item() # contradiction
28
+ score = entailment_prob - contradiction_prob
29
+ return round(score, 2)
30
+
31
+ # === Эндпоинт API ===
32
+ @app.post("/agreement")
33
+ def agreement(pair: MessagePair):
34
+ score = check_agreement(pair.msg1, pair.msg2)
35
+ return {"agreement_score": score}