# ABSA-REST-API / app.py
# (Hugging Face Space page header preserved as a comment:
#  author "EfektMotyla", commit d392497, verified)
from pathlib import Path
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
from transformers import MarianMTModel, MarianTokenizer
import torch
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
AutoModelForSequenceClassification,
pipeline,
)
import os
from huggingface_hub import snapshot_download
import logging
# ────────────────────── logging configuration ──────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ────────────────────── configuration ──────────────────────
# Run on GPU when available, otherwise fall back to CPU.
# (The original assigned `device` twice — the duplicate is removed.)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Local model directories shipped alongside this file.
ROOT = Path(__file__).parent
aspect_dir = ROOT / "bert-aspect-ner"
sentiment_dir = ROOT / "absa-roberta"

# Token for gated/private Hugging Face downloads (None when unset).
hf_token = os.getenv("HF_TOKEN")
# ────────────────────── local models ─────────────────────
# Aspect extractor: token-classification model stored next to this file.
aspect_tokenizer = AutoTokenizer.from_pretrained(
    str(aspect_dir),
    local_files_only=True,
    use_fast=False,  # slow tokenizer: the repo may ship without tokenizer.json
)
aspect_model = AutoModelForTokenClassification.from_pretrained(
    str(aspect_dir),
    local_files_only=True,
).to(device)

# Aspect-sentiment classifier: sequence-classification model, also local-only.
sentiment_tokenizer = AutoTokenizer.from_pretrained(
    str(sentiment_dir),
    local_files_only=True,
)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    str(sentiment_dir),
    local_files_only=True,
).to(device)
# ────────────────────── translation models (downloaded on-line) ─────────
HF_CACHE_DIR = "/tmp/hf_cache"
os.makedirs(HF_CACHE_DIR, exist_ok=True)
# NOTE(review): these env vars are assigned after transformers was imported
# and after the local models were already loaded — presumably they only
# matter for the downloads below, which also pass cache_dir explicitly;
# confirm they are still needed at all.
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
# Fetch both MarianMT snapshots (PL→EN and EN→PL) into the shared cache.
pl_to_en_dir = snapshot_download(
    "Helsinki-NLP/opus-mt-pl-en", token=hf_token, cache_dir=HF_CACHE_DIR
)
en_to_pl_dir = snapshot_download(
    "gsarti/opus-mt-tc-en-pl", token=hf_token, cache_dir=HF_CACHE_DIR
)
# Load tokenizers and move both models onto the chosen device.
pl_to_en_tok = MarianTokenizer.from_pretrained(pl_to_en_dir)
pl_to_en_mod = MarianMTModel.from_pretrained(pl_to_en_dir).to(device)
en_to_pl_tok = MarianTokenizer.from_pretrained(en_to_pl_dir)
en_to_pl_mod = MarianMTModel.from_pretrained(en_to_pl_dir).to(device)
# ────────────────────── Pydantic schemas ────────────────────
class Comment(BaseModel):
    """Request payload: one free-text review (Polish)."""
    text: str
class AspectSentiment(BaseModel):
    """One extracted aspect (Polish) with its predicted sentiment label."""
    aspect: str
    sentiment: str
class AnalysisResult(BaseModel):
    """Response payload: all aspect/sentiment pairs found in the comment."""
    results: List[AspectSentiment]
# Curated EN→PL aspect dictionary: preferred over machine translation when
# an extracted aspect has a known Polish equivalent (same mapping as before).
aspect_aliases = {
    "food": "jedzenie",
    "service": "obsΕ‚uga",
    "price": "cena",
    "taste": "smak",
    "waiter": "obsΕ‚uga",
    "dish": "danie",
    "portion": "porcja",
    "staff": "obsΕ‚uga",
    "decor": "wystrΓ³j",
    "menu": "menu",
    "drink": "napoje",
    "location": "lokalizacja",
    "time": "czas oczekiwania",
    "cleanliness": "czystoΕ›Δ‡",
    "smell": "zapach",
    "value": "cena",
    "experience": "doΕ›wiadczenie",
    "recommendation": "ogΓ³lna ocena",
    "children": "dzieci",
    "family": "rodzina",
    "pet": "zwierzΔ™ta",
}
# ───────────────────── translations ──────────────────────
def translate_pl_to_en(texts: list[str]) -> list[str]:
    """Translate a batch of Polish strings to English with the Marian model."""
    batch = pl_to_en_tok(
        texts, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        out_ids = pl_to_en_mod.generate(**batch)
    return pl_to_en_tok.batch_decode(out_ids, skip_special_tokens=True)
def translate_en_to_pl(texts: list[str]) -> list[str]:
    """Translate a batch of English strings back to Polish."""
    batch = en_to_pl_tok(
        texts, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        out_ids = en_to_pl_mod.generate(**batch)
    return en_to_pl_tok.batch_decode(out_ids, skip_special_tokens=True)
def extract_aspects(text_en: str) -> list[str]:
    """Run the BIO token classifier over *text_en* and return unique aspect phrases.

    Tokens tagged B-ASP start a span, I-ASP continues one; spans are merged
    back into surface text with the tokenizer. Results are deduplicated while
    preserving first-appearance order (the original used a set, which made
    the order nondeterministic between runs).
    """
    inputs = aspect_tokenizer(
        text_en, return_tensors="pt", truncation=True, padding=True
    ).to(device)
    with torch.no_grad():
        outputs = aspect_model(**inputs)
    preds = torch.argmax(outputs.logits, dim=2)[0].cpu().numpy()
    tokens = aspect_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [aspect_model.config.id2label[p] for p in preds]

    aspects: list[str] = []
    current_tokens: list[str] = []

    def _flush() -> None:
        # Close the span under construction (if any) and store its text.
        if current_tokens:
            aspects.append(
                aspect_tokenizer.convert_tokens_to_string(current_tokens).strip()
            )
            current_tokens.clear()

    for token, label in zip(tokens, labels):
        if label == "B-ASP":
            _flush()  # a new span starts; finish the previous one
            current_tokens.append(token)
        elif label == "I-ASP" and current_tokens:
            current_tokens.append(token)
        else:
            _flush()  # "O" (or orphan I-ASP) terminates the span
    _flush()  # span may run to the end of the sequence

    # Strip leftover WordPiece "##" markers, then dedupe keeping order.
    cleaned = [asp.replace(" ##", "") for asp in aspects]
    return list(dict.fromkeys(cleaned))
# ────────────────────── FastAPI ────────────────────────────
# Application instance served by the Space's ASGI runner.
app = FastAPI()
@app.post("/analyze", response_model=AnalysisResult)
def analyze_comment(comment: Comment):
    """Analyse one Polish review end-to-end.

    Pipeline: translate PL→EN → extract aspect phrases → classify the
    sentiment of each aspect → map (or machine-translate) the aspect name
    back to Polish. Returns a dict matching the AnalysisResult schema.
    """
    logger.info("Otrzymano zapytanie: %s", comment.text)  # lazy %-args
    text_pl = comment.text
    text_en = translate_pl_to_en([text_pl])[0]
    aspects = extract_aspects(text_en)

    # Class id → Polish sentiment label; built once, not per aspect as before.
    sentiment_labels = {
        0: "negatywny",
        1: "neutralny",
        2: "pozytywny",
        3: "konfliktowy",
    }

    results: list[AspectSentiment] = []
    for asp in aspects:
        # Sentence-pair input "<review> [SEP] <aspect>" for the classifier.
        input_text = f"{text_en} [SEP] {asp}"
        inputs = sentiment_tokenizer(
            input_text, return_tensors="pt", truncation=True, padding=True
        ).to(device)
        with torch.no_grad():
            logits = sentiment_model(**inputs).logits
        predicted_class_id = int(logits.argmax().cpu())
        sentiment_label = sentiment_labels[predicted_class_id]
        # Prefer the curated alias; fall back to machine translation.
        asp_pl = aspect_aliases.get(asp, translate_en_to_pl([asp])[0].lower())
        results.append(AspectSentiment(aspect=asp_pl, sentiment=sentiment_label))
    logger.info("WysΕ‚ano odpowiedΕΊ: %s dla zapytania: %s", results, comment.text)
    return {"results": results}