# Dohahemdann's picture
# Upload inference.py with huggingface_hub
# 88a4aeb verified
import torch
import torch.nn.functional as F
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
def load_model(model_dir):
    """Load the tokenizer, encoder, and classification heads stored in ``model_dir``.

    The directory must contain:

    * ``model_config.json`` -- JSON object with at least ``"hidden_dim"``
      (width of the hidden layer inside every head).
    * Files loadable by ``AutoTokenizer.from_pretrained`` and
      ``AutoModel.from_pretrained``.
    * ``classification_heads.pt`` -- a mapping with ``"aspect_head"`` (state
      dict of the aspect head) and ``"aspect_sentiment_heads"`` (state dicts
      of the per-aspect sentiment heads, indexable by position).

    Args:
        model_dir: Path (``str`` or ``Path``) of the model directory.

    Returns:
        Tuple ``(tokenizer, encoder, aspect_head, aspect_sentiment_heads, cfg)``
        where ``aspect_sentiment_heads`` is an ``nn.ModuleList`` and ``cfg`` is
        the parsed config dict. All modules are on CPU and in eval mode.
    """
    model_dir = Path(model_dir)
    # JSON is UTF-8 by specification; do not depend on the platform locale
    # (the config may hold non-ASCII aspect names).
    with open(model_dir / "model_config.json", encoding="utf-8") as f:
        cfg = json.load(f)

    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
    encoder = AutoModel.from_pretrained(str(model_dir))
    hidden = cfg["hidden_dim"]
    enc_dim = encoder.config.hidden_size

    def make_head(out_dim):
        # Single definition of the head architecture, shared by the aspect
        # head and the sentiment heads so the two cannot drift apart.
        return nn.Sequential(
            nn.Linear(enc_dim, hidden), nn.GELU(),
            nn.LayerNorm(hidden), nn.Dropout(0.0),
            nn.Linear(hidden, out_dim),
        )

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (or pass weights_only=True where the installed torch
    # version supports it).
    heads_state = torch.load(model_dir / "classification_heads.pt", map_location="cpu")
    aspect_state = heads_state["aspect_head"]

    # Read the number of aspects from the checkpoint instead of hard-coding
    # it: "4.weight" is the final Linear of the Sequential built by make_head,
    # and its row count is the number of aspect logits.
    num_aspects = aspect_state["4.weight"].shape[0]

    aspect_head = make_head(num_aspects)
    aspect_head.load_state_dict(aspect_state)

    # One 3-way sentiment head per aspect.
    aspect_sentiment_heads = nn.ModuleList([make_head(3) for _ in range(num_aspects)])
    for i, head in enumerate(aspect_sentiment_heads):
        head.load_state_dict(heads_state["aspect_sentiment_heads"][i])

    # This loader serves inference only, so put every module in eval mode.
    encoder.eval()
    aspect_head.eval()
    aspect_sentiment_heads.eval()

    return tokenizer, encoder, aspect_head, aspect_sentiment_heads, cfg
def predict(text, tokenizer, encoder, aspect_head, sentiment_heads, cfg, device="cpu"):
    """Detect the aspects mentioned in ``text`` and label each one's sentiment.

    The text is tokenized (truncated/padded to 256 tokens), encoded, and
    mean-pooled over the positions selected by the attention mask. The pooled
    vector feeds ``aspect_head`` (one sigmoid score per aspect) and every
    module in ``sentiment_heads`` (softmax whose first three entries are read
    as positive, negative, neutral, in that order).

    Args:
        text: Input handed to ``tokenizer``. Only the first row of the
            resulting batch is scored, so pass a single string.
        tokenizer: Callable returning a mapping of tensors that includes
            ``"attention_mask"``.
        encoder: Module whose output exposes ``last_hidden_state``.
        aspect_head: Module mapping the pooled vector to aspect logits.
        sentiment_heads: Iterable of modules, one per aspect, each mapping
            the pooled vector to sentiment logits.
        cfg: Mapping with ``"aspect_names"`` (list of names, index-aligned
            with the heads), ``"neutral_margin"`` and, optionally,
            ``"best_threshold"`` (defaults to 0.5).
        device: Device used for the forward pass. The encoder and all heads
            are moved to it (in place) together with the inputs.

    Returns:
        ``{"aspects": [names...], "sentiments": {name: label}}`` where label
        is ``"positive"``, ``"negative"`` or ``"neutral"``. If no aspect
        score exceeds the threshold, the single highest-scoring aspect is
        reported.
    """
    aspect_names = cfg["aspect_names"]
    neutral_margin = cfg["neutral_margin"]
    threshold = cfg.get("best_threshold", 0.5)

    # Inputs and modules must live on the same device; moving only the inputs
    # would raise a device-mismatch error for any device other than the one
    # the modules already sit on.
    encoder.to(device)
    aspect_head.to(device)
    for head in sentiment_heads:
        head.to(device)

    enc = tokenizer(text, return_tensors="pt",
                    max_length=256, padding="max_length", truncation=True)
    inputs = {
        name: value.to(device) if torch.is_tensor(value) else value
        for name, value in enc.items()
    }

    with torch.no_grad():
        out = encoder(**inputs)
        # Mean-pool over the positions flagged by the attention mask; the
        # clamp guards against division by zero for an all-zero mask.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        rep = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        asp_logits = aspect_head(rep)
        sent_logits = torch.stack([head(rep) for head in sentiment_heads], dim=1)

        asp_probs = torch.sigmoid(asp_logits[0])
        # Only consider aspects that have a name, a logit, and a sentiment
        # head, rather than assuming a fixed count.
        num_aspects = min(len(aspect_names), asp_probs.shape[0], sent_logits.shape[1])
        asp_probs = asp_probs[:num_aspects]

        hits = (asp_probs > threshold).nonzero(as_tuple=True)[0].tolist()
        if not hits:
            # Always report at least one aspect: fall back to the best score.
            hits = [asp_probs.argmax().item()]

        sentiments = {}
        for j in hits:
            pos, neg, neu = F.softmax(sent_logits[0, j], dim=0)[:3].tolist()
            # Neutral must beat the stronger polar class by the margin;
            # otherwise pick the larger polar class (ties go to positive).
            if neu >= max(pos, neg) + neutral_margin:
                label = "neutral"
            elif pos >= neg:
                label = "positive"
            else:
                label = "negative"
            sentiments[aspect_names[j]] = label

    return {"aspects": [aspect_names[j] for j in hits], "sentiments": sentiments}
if __name__ == "__main__":
    # Smoke run: load the model from the current directory and print the
    # prediction for one sample review.
    components = load_model(".")
    sample = "[AR] الأكل كان تمام والخدمة ممتازة لكن السعر غالي شوية"
    # load_model returns its pieces in the order predict expects them after
    # the text argument, so the tuple can be forwarded as-is.
    out = predict(sample, *components)
    print(out)