# ABSA-REST-API / app.py
# Source: Hugging Face Space by EfektMotyla, commit d9095b6 ("Update app.py"), 4 kB.
from fastapi import FastAPI, Request
from pydantic import BaseModel
from typing import List
from transformers import (
AutoTokenizer, AutoModelForTokenClassification,
AutoModelForSequenceClassification, pipeline
)
import torch
import os
# Use the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# transformers pipelines expect an integer device index: 0 = first GPU, -1 = CPU.
_PIPELINE_DEVICE = 0 if device == "cuda" else -1

# Token classifier whose B-ASP / I-ASP tags are decoded by extract_aspects.
_ASPECT_CHECKPOINT = "EfektMotyla/bert-aspect-ner"
aspect_tokenizer = AutoTokenizer.from_pretrained(_ASPECT_CHECKPOINT)
aspect_model = AutoModelForTokenClassification.from_pretrained(_ASPECT_CHECKPOINT).to(device)

# Sequence classifier that scores the sentiment of a "<sentence> [SEP] <aspect>" input.
_SENTIMENT_CHECKPOINT = "EfektMotyla/absa-roberta"
sentiment_tokenizer = AutoTokenizer.from_pretrained(_SENTIMENT_CHECKPOINT)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(_SENTIMENT_CHECKPOINT).to(device)

# Translation pipelines: Polish input -> English for the models, English aspects -> Polish.
pl_to_en = pipeline("translation", model="Helsinki-NLP/opus-mt-pl-en", device=_PIPELINE_DEVICE)
en_to_pl = pipeline("translation", model="gsarti/opus-mt-tc-en-pl", device=_PIPELINE_DEVICE)
# === Input and output data ===
# Request body of POST /analyze: the raw comment text (treated as Polish by analyze_comment).
class Comment(BaseModel):
    text: str
# One result entry: an aspect name (reported in Polish) and its sentiment label.
class AspectSentiment(BaseModel):
    aspect: str
    sentiment: str
# Response body of POST /analyze: one AspectSentiment per detected aspect.
class AnalysisResult(BaseModel):
    results: List[AspectSentiment]
# === EN -> PL aspect alias dictionary ===
# Maps lower-cased English aspect terms to the Polish name reported by the API.
# Several English terms intentionally share one Polish label (e.g. "service",
# "waiter" and "staff" all map to "obsługa"). Aspects missing from this table
# are machine-translated by analyze_comment instead.
aspect_aliases = {
    "food": "jedzenie", "service": "obsługa", "price": "cena",
    "taste": "smak", "waiter": "obsługa", "dish": "danie",
    "portion": "porcja", "staff": "obsługa", "decor": "wystrój",
    "menu": "menu", "drink": "napoje", "location": "lokalizacja",
    "time": "czas oczekiwania", "cleanliness": "czystość", "smell": "zapach",
    "value": "cena", "experience": "doświadczenie", "recommendation": "ogólna ocena",
    "children": "dzieci", "family": "rodzina", "pet": "zwierzęta"
    # add more entries as needed
}
# === Funkcje pomocnicze ===
def translate_pl_to_en(texts):
    """Translate Polish text to English with the ``pl_to_en`` pipeline.

    Args:
        texts: Polish input handed directly to the pipeline (a list of strings
            in this file's call sites).

    Returns:
        list[str]: the ``translation_text`` of every pipeline result, in order.
    """
    english = []
    for item in pl_to_en(texts):
        english.append(item["translation_text"])
    return english
def translate_en_to_pl(texts):
    """Translate English text to Polish with the ``en_to_pl`` pipeline.

    Args:
        texts: English input handed directly to the pipeline (a list of strings
            in this file's call sites).

    Returns:
        list[str]: the ``translation_text`` of every pipeline result, in order.
    """
    polish = []
    for item in en_to_pl(texts):
        polish.append(item["translation_text"])
    return polish
def _tokens_to_words(tokens, labels, encoding):
    """Collapse sub-token predictions into one ``(sub_tokens, label)`` pair per word.

    Special tokens (e.g. [CLS]/[SEP]/padding) are dropped so they can never be
    decoded as aspects. With a fast tokenizer, word boundaries come from
    ``encoding.word_ids()`` and each word takes the label predicted for its first
    sub-token. With a slow tokenizer, every non-special token is its own word.
    """
    if aspect_tokenizer.is_fast:
        word_ids = encoding.word_ids(0)
    else:
        special = set(aspect_tokenizer.all_special_tokens)
        word_ids = [None if tok in special else idx for idx, tok in enumerate(tokens)]

    words = []
    previous_id = None
    for token, label, word_id in zip(tokens, labels, word_ids):
        if word_id is None:
            previous_id = None
            continue
        if word_id == previous_id:
            # Continuation sub-token: keep it in the word, ignore its own label.
            words[-1][0].append(token)
        else:
            words.append(([token], label))
        previous_id = word_id
    return words


def _decode_aspect_spans(words):
    """Group word-level BIO labels into aspect spans (lists of sub-tokens).

    ``B-ASP`` starts a span, ``I-ASP`` extends the open span, and any other label
    closes it. A stray ``I-ASP`` with no open span starts a new span instead of
    being dropped (lenient BIO decoding).
    """
    spans = []
    current = []
    for sub_tokens, label in words:
        if label == "I-ASP" and current:
            current.extend(sub_tokens)
            continue
        if current:
            spans.append(current)
            current = []
        if label in ("B-ASP", "I-ASP"):
            current = list(sub_tokens)
    if current:
        spans.append(current)
    return spans


def extract_aspects(text_en):
    """Extract aspect terms from English text with the BIO token classifier.

    The text is tokenized and tagged by ``aspect_model``. Predictions are first
    collapsed to word level, then grouped into ``B-ASP``/``I-ASP`` spans, and each
    span is turned back into a string.

    Args:
        text_en: English text to analyse.

    Returns:
        list[str]: lower-cased, de-duplicated, non-empty aspect strings in order of
        first appearance.
    """
    inputs = aspect_tokenizer(text_en, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = aspect_model(**inputs)
    pred_ids = torch.argmax(outputs.logits, dim=2)[0].cpu().tolist()
    tokens = aspect_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [aspect_model.config.id2label[p] for p in pred_ids]

    words = _tokens_to_words(tokens, labels, inputs)
    aspects = (
        aspect_tokenizer.convert_tokens_to_string(span).strip().lower()
        for span in _decode_aspect_spans(words)
    )
    # dict.fromkeys de-duplicates while keeping first-seen order, so the API
    # output order is stable across runs (a set's order depends on string hashing).
    return list(dict.fromkeys(a for a in aspects if a))
# === Main API function ===
app = FastAPI()

# Output index of the sentiment classifier -> Polish label returned by the API
# (negative, neutral, positive, conflict).
# NOTE(review): assumes this matches the model's class order; confirm against
# sentiment_model.config.id2label.
SENTIMENT_LABELS = {0: "negatywny", 1: "neutralny", 2: "pozytywny", 3: "konfliktowy"}


@app.post("/analyze", response_model=AnalysisResult)
def analyze_comment(comment: Comment):
    """Detect aspects in a Polish comment and classify the sentiment toward each.

    The comment is translated to English and aspects are extracted from the
    translation. Each "<sentence> [SEP] <aspect>" input is scored by
    ``sentiment_model``. Each aspect name is then mapped back to Polish, via
    ``aspect_aliases`` or, failing that, machine translation.

    Args:
        comment: request body holding the Polish text.

    Returns:
        dict: ``{"results": [AspectSentiment, ...]}`` with one entry per aspect.
    """
    text_pl = comment.text
    # Nothing to analyse: skip all model calls for blank input.
    if not text_pl.strip():
        return {"results": []}

    text_en = translate_pl_to_en([text_pl])[0]
    aspects = extract_aspects(text_en)
    result = []
    for asp in aspects:
        # NOTE(review): "[SEP]" is inserted as literal text. The checkpoint name
        # suggests RoBERTa, whose native separator differs, so this presumably
        # mirrors the fine-tuning input format. Confirm before changing it.
        input_text = f"{text_en} [SEP] {asp}"
        inputs = sentiment_tokenizer(input_text, return_tensors="pt", truncation=True, padding=True).to(device)
        with torch.no_grad():
            logits = sentiment_model(**inputs).logits
        predicted_class_id = int(logits.argmax().cpu())
        sentiment_label = SENTIMENT_LABELS[predicted_class_id]
        # Look up the alias first so the EN->PL translation model only runs for
        # aspects the table does not cover. dict.get(key, default) would
        # evaluate the translation call eagerly on every iteration.
        asp_pl = aspect_aliases.get(asp)
        if asp_pl is None:
            asp_pl = translate_en_to_pl([asp])[0].lower()
        result.append(AspectSentiment(aspect=asp_pl, sentiment=sentiment_label))
    return {"results": result}