Spaces:
Sleeping
Sleeping
File size: 4,004 Bytes
5ac897e b12e63e e69a3fb 5ac897e af2576d 5ac897e af2576d 5ac897e e69a3fb af2576d e69a3fb af2576d e69a3fb 5ac897e e69a3fb 5ac897e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from fastapi import FastAPI, Request
from pydantic import BaseModel
from typing import List
from transformers import (
AutoTokenizer, AutoModelForTokenClassification,
AutoModelForSequenceClassification, pipeline
)
import torch
import os
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# transformers pipelines take a device index: 0 for the first GPU, -1 for CPU.
_pipeline_device = 0 if device == "cuda" else -1

# Token-classification model that tags aspect terms.
aspect_tokenizer = AutoTokenizer.from_pretrained("EfektMotyla/bert-aspect-ner")
aspect_model = AutoModelForTokenClassification.from_pretrained(
    "EfektMotyla/bert-aspect-ner"
)
aspect_model = aspect_model.to(device)

# Sequence-classification model that scores sentiment for a text/aspect pair.
sentiment_tokenizer = AutoTokenizer.from_pretrained("EfektMotyla/absa-roberta")
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    "EfektMotyla/absa-roberta"
)
sentiment_model = sentiment_model.to(device)

# Translation pipelines: Polish -> English and English -> Polish.
pl_to_en = pipeline(
    "translation",
    model="Helsinki-NLP/opus-mt-pl-en",
    device=_pipeline_device,
)
en_to_pl = pipeline(
    "translation",
    model="gsarti/opus-mt-tc-en-pl",
    device=_pipeline_device,
)
# === Input and output data models ===
class Comment(BaseModel):
    """Request payload carrying a single comment to analyze."""

    # Raw comment text; analyze_comment feeds it to the Polish->English translator.
    text: str
class AspectSentiment(BaseModel):
    """One detected aspect paired with the sentiment predicted for it."""

    # Name of the aspect.
    aspect: str
    # Sentiment label assigned to that aspect.
    sentiment: str
class AnalysisResult(BaseModel):
    """Response payload: the aspect/sentiment pairs found in a comment."""

    # One entry per detected aspect; may be empty.
    results: List[AspectSentiment]
# === Aspect alias dictionary EN→PL (same as before) ===
# Maps an English aspect term to the Polish name reported to clients.
# Several English terms intentionally share one Polish name (e.g. "service",
# "waiter" and "staff" all map to "obsługa").
# The Polish values were stored as mojibake (UTF-8 bytes decoded with the
# wrong code page); they are restored to correct Unicode here.
aspect_aliases = {
    "food": "jedzenie", "service": "obsługa", "price": "cena",
    "taste": "smak", "waiter": "obsługa", "dish": "danie",
    "portion": "porcja", "staff": "obsługa", "decor": "wystrój",
    "menu": "menu", "drink": "napoje", "location": "lokalizacja",
    "time": "czas oczekiwania", "cleanliness": "czystość", "smell": "zapach",
    "value": "cena", "experience": "doświadczenie", "recommendation": "ogólna ocena",
    "children": "dzieci", "family": "rodzina", "pet": "zwierzęta",
    # add more entries as needed
}
# === Helper functions ===
def translate_pl_to_en(texts):
    """Run ``texts`` through the ``pl_to_en`` pipeline and return the translations.

    Args:
        texts: Input forwarded unchanged to the ``pl_to_en`` pipeline
            (callers in this file pass a list of strings).

    Returns:
        A list with the ``"translation_text"`` value of each pipeline result,
        in the order the pipeline returned them.
    """
    outputs = pl_to_en(texts)
    return [entry["translation_text"] for entry in outputs]
def translate_en_to_pl(texts):
    """Run ``texts`` through the ``en_to_pl`` pipeline and return the translations.

    Args:
        texts: Input forwarded unchanged to the ``en_to_pl`` pipeline
            (callers in this file pass a list of strings).

    Returns:
        A list with the ``"translation_text"`` value of each pipeline result,
        in the order the pipeline returned them.
    """
    outputs = en_to_pl(texts)
    return [entry["translation_text"] for entry in outputs]
def extract_aspects(text_en):
    """Return the unique, lower-cased aspect phrases tagged in ``text_en``.

    The text is tokenized with ``aspect_tokenizer`` and labelled token by
    token with ``aspect_model``. A ``B-ASP`` token opens a phrase and each
    directly following ``I-ASP`` token extends it; any other label closes the
    phrase. An ``I-ASP`` token with no open phrase is ignored.

    Args:
        text_en: Text to tag (a single string; only the first sequence of the
            tokenized batch is read).

    Returns:
        A list of distinct aspect phrases, stripped and lower-cased, in order
        of first appearance. Empty phrases are dropped. The list is empty when
        no token is labelled ``B-ASP``.
    """
    inputs = aspect_tokenizer(
        text_en, return_tensors="pt", truncation=True, padding=True
    ).to(device)
    with torch.no_grad():
        outputs = aspect_model(**inputs)

    # Best-scoring label id for every token of the first sequence.
    label_ids = torch.argmax(outputs.logits, dim=2)[0].tolist()
    tokens = aspect_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [aspect_model.config.id2label[label_id] for label_id in label_ids]

    # Group tokens into B-ASP (I-ASP)* runs.
    spans = []
    current_tokens = []
    for token, label in zip(tokens, labels):
        if label == "I-ASP" and current_tokens:
            current_tokens.append(token)
            continue
        # Any other case ends the open phrase, if there is one.
        if current_tokens:
            spans.append(current_tokens)
        current_tokens = [token] if label == "B-ASP" else []
    if current_tokens:
        spans.append(current_tokens)

    # A dict keeps insertion order, so de-duplicating through it yields a
    # stable first-seen order (a set would give hash-dependent ordering).
    unique_aspects = {}
    for span in spans:
        phrase = aspect_tokenizer.convert_tokens_to_string(span).strip().lower()
        if phrase:
            unique_aspects.setdefault(phrase, None)
    return list(unique_aspects)
# === Main API function ===
# FastAPI application object on which the routes below are registered.
app = FastAPI()
@app.post("/analyze", response_model=AnalysisResult)
def analyze_comment(comment: Comment):
    """Detect aspects in a comment and classify the sentiment of each one.

    The comment text is translated with ``translate_pl_to_en``, aspects are
    extracted from the translation with ``extract_aspects``, and each aspect
    is scored by ``sentiment_model`` on the input ``"<text> [SEP] <aspect>"``.
    The aspect is reported under its name from ``aspect_aliases`` when one
    exists; otherwise it is translated with ``translate_en_to_pl`` and
    lower-cased.

    Args:
        comment: Request body holding the text to analyze.

    Returns:
        A dict ``{"results": [...]}`` with one ``AspectSentiment`` per
        extracted aspect (an empty list when no aspect is found).

    Raises:
        KeyError: If the model predicts a class id outside 0-3.
    """
    # Class id -> label returned to the client; built once, not per aspect.
    sentiment_names = {0: "negatywny", 1: "neutralny", 2: "pozytywny", 3: "konfliktowy"}

    text_en = translate_pl_to_en([comment.text])[0]

    result = []
    for asp in extract_aspects(text_en):
        input_text = f"{text_en} [SEP] {asp}"
        inputs = sentiment_tokenizer(
            input_text, return_tensors="pt", truncation=True, padding=True
        ).to(device)
        with torch.no_grad():
            logits = sentiment_model(**inputs).logits
        predicted_class_id = int(logits.argmax().cpu())
        sentiment_label = sentiment_names[predicted_class_id]

        # Only run the translation model when no alias is defined; passing the
        # translation as the default of dict.get() would evaluate it every time.
        asp_pl = aspect_aliases.get(asp)
        if asp_pl is None:
            asp_pl = translate_en_to_pl([asp])[0].lower()

        result.append(AspectSentiment(aspect=asp_pl, sentiment=sentiment_label))
    return {"results": result}
|