# ABSA-REST-API / app.py
# Repository page header: author EfektMotyla, last commit 9d9d143
# "Update app.py" (verified), file size 5.11 kB.
from pathlib import Path
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import torch
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
AutoModelForSequenceClassification,
pipeline,
)
# ────────────────────── configuration ──────────────────────
# Use the GPU when torch can see one; otherwise everything runs on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Model directories are anchored at the directory containing this file.
ROOT = Path(__file__).parent
MODELS_DIR = ROOT.joinpath("models")
aspect_dir = MODELS_DIR.joinpath("bert-aspect-ner")
sentiment_dir = MODELS_DIR.joinpath("absa-roberta")
# ────────────────────── local models ─────────────────────
# Both models are read from aspect_dir / sentiment_dir only
# (local_files_only=True, so nothing is downloaded), then moved to `device`.

# Aspect extractor (token classification). use_fast=False selects the slow
# Python tokenizer; per the original author's note this is needed when the
# model directory has no tokenizer.json.
aspect_tokenizer = AutoTokenizer.from_pretrained(
    str(aspect_dir),
    use_fast=False,
    local_files_only=True,
)
aspect_model = AutoModelForTokenClassification.from_pretrained(
    str(aspect_dir),
    local_files_only=True,
)
aspect_model = aspect_model.to(device)

# Aspect-sentiment classifier (sequence classification).
sentiment_tokenizer = AutoTokenizer.from_pretrained(
    str(sentiment_dir),
    local_files_only=True,
)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    str(sentiment_dir),
    local_files_only=True,
)
sentiment_model = sentiment_model.to(device)
# ────────────────────── translation models (on-line) ─────────
# These are referenced by Hub model id rather than a local path, so they are
# fetched from the network (or the local HF cache) at start-up.
# `pipeline` takes an integer device index: 0 = first CUDA GPU, -1 = CPU.
_pipeline_device = -1
if device == "cuda":
    _pipeline_device = 0

# Polish -> English, used on the incoming comment.
pl_to_en = pipeline(
    "translation", model="Helsinki-NLP/opus-mt-pl-en", device=_pipeline_device
)
# English -> Polish, used for aspect names missing from the alias table.
en_to_pl = pipeline(
    "translation", model="gsarti/opus-mt-tc-en-pl", device=_pipeline_device
)
# ────────────────────── Pydantic schemas ────────────────────


# Request body of POST /analyze.
class Comment(BaseModel):
    # Comment text; analyze_comment treats it as Polish and translates it to English.
    text: str
# A single (aspect, sentiment) pair in the /analyze response.
class AspectSentiment(BaseModel):
    # Aspect name as returned to the client (analyze_comment fills it in Polish).
    aspect: str
    # Sentiment label string; not constrained by the schema itself.
    sentiment: str
# Response body of POST /analyze.
class AnalysisResult(BaseModel):
    # One entry per aspect found in the comment.
    results: List[AspectSentiment]
# === EN->PL aspect alias table (same as in the previous version) ===
# Maps English aspect terms extracted from the translated text to canonical
# Polish aspect names returned to the client. Several English synonyms map
# to the same Polish name on purpose (e.g. service / waiter / staff).
# Values must be proper UTF-8 Polish text: they are sent verbatim in the API
# response.
aspect_aliases = {
    "food": "jedzenie",
    "service": "obsługa",
    "price": "cena",
    "taste": "smak",
    "waiter": "obsługa",
    "dish": "danie",
    "portion": "porcja",
    "staff": "obsługa",
    "decor": "wystrój",
    "menu": "menu",
    "drink": "napoje",
    "location": "lokalizacja",
    "time": "czas oczekiwania",
    "cleanliness": "czystość",
    "smell": "zapach",
    "value": "cena",
    "experience": "doświadczenie",
    "recommendation": "ogólna ocena",
    "children": "dzieci",
    "family": "rodzina",
    "pet": "zwierzęta",
}
def _collect_translations(translator, texts):
    # Run a translation pipeline and keep only the "translation_text" of each output.
    outputs = translator(texts)
    return [output["translation_text"] for output in outputs]


def translate_pl_to_en(texts):
    """Translate Polish strings to English with the ``pl_to_en`` pipeline.

    Args:
        texts: List of Polish strings.

    Returns:
        List of English translation strings, one per pipeline output.
    """
    return _collect_translations(pl_to_en, texts)


def translate_en_to_pl(texts):
    """Translate English strings to Polish with the ``en_to_pl`` pipeline.

    Args:
        texts: List of English strings.

    Returns:
        List of Polish translation strings, one per pipeline output.
    """
    return _collect_translations(en_to_pl, texts)
def extract_aspects(text_en: str):
    """Extract aspect terms from English text with the token-classification model.

    The text is tokenized with ``aspect_tokenizer``. Each token gets its
    argmax label from ``aspect_model``, and runs of ``B-ASP`` / ``I-ASP``
    tokens are joined back into strings. The scheme is strict BIO: an
    ``I-ASP`` with no open span is ignored.

    Args:
        text_en: English input text. It is truncated to the tokenizer's
            maximum length.

    Returns:
        A list of distinct, non-empty aspect strings in order of first
        appearance in the text.
    """
    inputs = aspect_tokenizer(
        text_en, return_tensors="pt", truncation=True, padding=True
    ).to(device)
    with torch.no_grad():
        outputs = aspect_model(**inputs)
    pred_ids = torch.argmax(outputs.logits, dim=2)[0].tolist()
    tokens = aspect_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [aspect_model.config.id2label[p] for p in pred_ids]

    # The tokenizer adds these tokens around the text. Whatever label the model
    # gives them, they must not join or start an aspect; they only close one.
    boundary_tokens = {
        aspect_tokenizer.cls_token,
        aspect_tokenizer.sep_token,
        aspect_tokenizer.pad_token,
    } - {None}

    spans = []  # token list of each aspect, in order of appearance
    current = None  # token list of the span being collected, or None
    for token, label in zip(tokens, labels):
        if token in boundary_tokens:
            current = None
        elif current is not None and token.startswith("##"):
            # WordPiece continuation of the word being collected. Keep the
            # whole word in the span regardless of this sub-word's own label,
            # so that an aspect never ends or restarts in the middle of a word.
            current.append(token)
        elif label == "B-ASP":
            current = [token]
            spans.append(current)
        elif label == "I-ASP" and current is not None:
            current.append(token)
        else:
            current = None

    # Rebuild each span's text and remove leftover " ##" sub-word joints.
    aspects = (
        aspect_tokenizer.convert_tokens_to_string(span).strip().replace(" ##", "")
        for span in spans
    )
    # dict.fromkeys de-duplicates while keeping first-appearance order. A set
    # would make the response order depend on the per-process string hash seed.
    return list(dict.fromkeys(a for a in aspects if a))
# ────────────────────── FastAPI ────────────────────────────
app = FastAPI()

# Sentiment class index -> Polish label returned to the client.
# NOTE(review): assumes the sentiment model's head uses exactly this index
# order; confirm against sentiment_model.config.id2label.
_SENTIMENT_LABELS = {
    0: "negatywny",
    1: "neutralny",
    2: "pozytywny",
    3: "konfliktowy",
}


@app.post("/analyze", response_model=AnalysisResult)
def analyze_comment(comment: Comment):
    """Run aspect-based sentiment analysis on a Polish comment.

    The comment is translated to English and aspects are extracted from the
    translation. Each aspect is classified together with the whole sentence.
    Aspect names are mapped back to Polish through ``aspect_aliases``, with
    machine translation as the fallback for aspects not in the table.

    Args:
        comment: Request body holding the Polish text.

    Returns:
        ``{"results": [...]}`` with one ``AspectSentiment`` per extracted aspect.

    Raises:
        KeyError: if the sentiment model predicts a class index outside 0-3.
    """
    text_pl = comment.text
    text_en = translate_pl_to_en([text_pl])[0]
    aspects = extract_aspects(text_en)

    results: list[AspectSentiment] = []
    for asp in aspects:
        # The sentence and the aspect are joined with a literal "[SEP]" string.
        # NOTE(review): for a RoBERTa tokenizer this is not the real separator
        # token. It is kept as-is on the assumption that the model was trained
        # on this exact format; confirm.
        input_text = f"{text_en} [SEP] {asp}"
        inputs = sentiment_tokenizer(
            input_text, return_tensors="pt", truncation=True, padding=True
        ).to(device)
        with torch.no_grad():
            logits = sentiment_model(**inputs).logits
        predicted_class_id = int(logits.argmax().cpu())
        sentiment_label = _SENTIMENT_LABELS[predicted_class_id]

        # Try the alias table first and run the slow en->pl translator only on
        # a miss. The previous dict.get(key, translate(...)) evaluated the
        # translation eagerly for every aspect. Alias keys are lowercase, so
        # the lookup is lowercased too.
        asp_pl = aspect_aliases.get(asp.lower())
        if asp_pl is None:
            asp_pl = translate_en_to_pl([asp])[0].lower()
        results.append(AspectSentiment(aspect=asp_pl, sentiment=sentiment_label))
    return {"results": results}