# ABSA-REST-API / app.py
# EfektMotyla's picture
# Update app.py
# 668f19f verified
# raw
# history blame
# 6.08 kB
from pathlib import Path
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
from transformers import MarianMTModel, MarianTokenizer
import torch
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
AutoModelForSequenceClassification,
pipeline,
)
import os
# ────────────────────── configuration ──────────────────────
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model directories are resolved relative to the directory containing this file.
ROOT = Path(__file__).parent
aspect_dir = ROOT / "bert-aspect-ner"
sentiment_dir = ROOT / "absa-roberta"

# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
hf_token = os.getenv("HF_TOKEN")
# ────────────────────── local models ─────────────────────
# Aspect tagger (token classification), loaded strictly from the local directory.
aspect_tokenizer = AutoTokenizer.from_pretrained(
    str(aspect_dir),
    local_files_only=True,
    use_fast=False,  # slow tokenizer, for the case where tokenizer.json is missing
)
aspect_model = AutoModelForTokenClassification.from_pretrained(
    str(aspect_dir),
    local_files_only=True,
)
aspect_model.to(device)

# Sentiment classifier (sequence classification), also loaded from local files only.
sentiment_tokenizer = AutoTokenizer.from_pretrained(
    str(sentiment_dir),
    local_files_only=True,
)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    str(sentiment_dir),
    local_files_only=True,
)
sentiment_model.to(device)
# ────────────────────── translation models (on-line) ─────────
# snapshot_download is not imported at the top of the file, so it is imported
# here; without it the calls below raise NameError at import time.
from huggingface_hub import snapshot_download

# Fetch the translation checkpoints from the Hugging Face Hub and get their
# local snapshot directories.
pl_to_en_dir = snapshot_download("Helsinki-NLP/opus-mt-pl-en", token=hf_token)
en_to_pl_dir = snapshot_download("gsarti/opus-mt-tc-en-pl", token=hf_token)

# Load the tokenizers and models from the downloaded snapshots.
pl_to_en_tokenizer = MarianTokenizer.from_pretrained(pl_to_en_dir)
pl_to_en_model = MarianMTModel.from_pretrained(pl_to_en_dir).to(device)
en_to_pl_tokenizer = MarianTokenizer.from_pretrained(en_to_pl_dir)
en_to_pl_model = MarianMTModel.from_pretrained(en_to_pl_dir).to(device)
# Translation helpers
def translate_pl_to_en(texts):
    """Translate a batch of Polish texts to English with the PL→EN Marian model.

    Args:
        texts: Polish strings to translate, passed to the tokenizer as one
            padded and truncated batch.

    Returns:
        The decoded model outputs as a list of strings (special tokens
        removed). An empty list or tuple yields ``[]`` without running the
        model.
    """
    # Nothing to tokenize or generate for an empty batch.
    if isinstance(texts, (list, tuple)) and not texts:
        return []
    batch = pl_to_en_tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        generated = pl_to_en_model.generate(**batch)
    return pl_to_en_tokenizer.batch_decode(generated, skip_special_tokens=True)
def translate_en_to_pl(texts):
    """Translate a batch of English texts to Polish with the EN→PL Marian model.

    Args:
        texts: English strings to translate, passed to the tokenizer as one
            padded and truncated batch.

    Returns:
        The decoded model outputs as a list of strings (special tokens
        removed). An empty list or tuple yields ``[]`` without running the
        model.
    """
    # Nothing to tokenize or generate for an empty batch.
    if isinstance(texts, (list, tuple)) and not texts:
        return []
    batch = en_to_pl_tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        generated = en_to_pl_model.generate(**batch)
    return en_to_pl_tokenizer.batch_decode(generated, skip_special_tokens=True)
# ────────────────────── Pydantic schemas ────────────────────
# NOTE: these models are documented with comments rather than docstrings so
# the generated JSON schema descriptions stay unchanged.


# Input payload: a single comment given as plain text.
class Comment(BaseModel):
    text: str


# One aspect paired with the sentiment label assigned to it.
class AspectSentiment(BaseModel):
    aspect: str
    sentiment: str


# Output payload: every aspect/sentiment pair found for one comment.
class AnalysisResult(BaseModel):
    results: List[AspectSentiment]
# === EN→PL aspect alias dictionary ===
# Maps a lowercase English aspect term to the Polish label reported for it.
# Several English terms deliberately share one Polish label (e.g. "service",
# "waiter" and "staff" all map to "obsługa").
aspect_aliases = {
    "food": "jedzenie", "service": "obsługa", "price": "cena",
    "taste": "smak", "waiter": "obsługa", "dish": "danie",
    "portion": "porcja", "staff": "obsługa", "decor": "wystrój",
    "menu": "menu", "drink": "napoje", "location": "lokalizacja",
    "time": "czas oczekiwania", "cleanliness": "czystość", "smell": "zapach",
    "value": "cena", "experience": "doświadczenie", "recommendation": "ogólna ocena",
    "children": "dzieci", "family": "rodzina", "pet": "zwierzęta",
}
# NOTE: translate_pl_to_en / translate_en_to_pl are intentionally NOT redefined
# here. The pipeline-based versions that stood at this point called `pl_to_en`
# and `en_to_pl`, which this module never defines, and because they came later
# they replaced the working MarianMT implementations defined earlier in the
# file, so every translation call raised NameError.
def _group_aspect_spans(tokens, labels):
    """Group tokens into aspect spans according to their B-ASP / I-ASP labels."""
    spans, current = [], []
    for token, label in zip(tokens, labels):
        if label == "I-ASP" and current:
            # Continuation of the span that is currently open.
            current.append(token)
            continue
        # Any other label closes the open span; an I-ASP with no open span
        # lands here as well and is ignored.
        if current:
            spans.append(current)
        current = [token] if label == "B-ASP" else []
    if current:
        spans.append(current)
    return spans


def extract_aspects(text_en: str):
    """Return the distinct aspect phrases the token classifier finds in a text.

    The text is tokenized and tagged by the aspect model. A token labelled
    ``B-ASP`` opens a phrase, following ``I-ASP`` tokens extend it, and any
    other label closes it.

    Args:
        text_en: Text to analyse (callers pass the English translation of the
            comment).

    Returns:
        A list of unique, non-empty aspect strings in order of first
        appearance, so the result is deterministic from run to run.
    """
    inputs = aspect_tokenizer(
        text_en, return_tensors="pt", truncation=True, padding=True
    ).to(device)
    with torch.no_grad():  # inference only
        outputs = aspect_model(**inputs)

    # A single text is passed in, so only row 0 of the batch is relevant.
    label_ids = torch.argmax(outputs.logits, dim=2)[0].tolist()
    tokens = aspect_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    id2label = aspect_model.config.id2label
    labels = [id2label[label_id] for label_id in label_ids]

    # Join each span back into text and drop any leftover " ##" word-piece marker.
    phrases = (
        aspect_tokenizer.convert_tokens_to_string(span).strip().replace(" ##", "")
        for span in _group_aspect_spans(tokens, labels)
    )
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(phrase for phrase in phrases if phrase))
# ────────────────────── FastAPI ────────────────────────────
app = FastAPI()

# Class index produced by the sentiment model → Polish label returned by the API.
_SENTIMENT_LABELS = {
    0: "negatywny",
    1: "neutralny",
    2: "pozytywny",
    3: "konfliktowy",
}


def _aspect_to_polish(aspect_en: str) -> str:
    """Return the Polish name of an English aspect, preferring the alias table."""
    alias = aspect_aliases.get(aspect_en.lower())
    if alias is not None:
        return alias
    # Run the translation model only when no alias exists.
    return translate_en_to_pl([aspect_en])[0].lower()


@app.post("/analyze", response_model=AnalysisResult)
def analyze_comment(comment: Comment):
    """Extract aspects from a Polish comment and classify the sentiment of each.

    The comment is translated to English, aspects are extracted from the
    translation, and each aspect is scored by the sentiment model on the input
    "<translated text> [SEP] <aspect>". Aspect names are reported in Polish.

    Args:
        comment: Request body carrying the comment text.

    Returns:
        A dict with a ``results`` list of aspect/sentiment pairs.

    Raises:
        KeyError: If the sentiment model predicts a class index outside 0-3.
    """
    text_en = translate_pl_to_en([comment.text])[0]

    results: list[AspectSentiment] = []
    for aspect_en in extract_aspects(text_en):
        inputs = sentiment_tokenizer(
            f"{text_en} [SEP] {aspect_en}",
            return_tensors="pt",
            truncation=True,
            padding=True,
        ).to(device)
        with torch.no_grad():  # inference only
            logits = sentiment_model(**inputs).logits
        # One sequence per call, so the flat argmax is the predicted class index.
        sentiment_label = _SENTIMENT_LABELS[int(logits.argmax().cpu())]
        results.append(
            AspectSentiment(
                aspect=_aspect_to_polish(aspect_en),
                sentiment=sentiment_label,
            )
        )
    return {"results": results}