File size: 1,919 Bytes
c5bcbe7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
os.environ["USE_TF"] = "0"
import torch
import gradio as gr
from transformers import DebertaV2Tokenizer
from src.model import SentiNetTransformer
from src.config import HPARAMS
from src.ui import build_demo

# CONFIGURATION
# Run on the GPU when one is available; the model and all inputs go here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hp = HPARAMS()
backbone_config_path = "model/config.json"
checkpoint_path = "model/SentiNet_Transformer_params.pt"
checkpint_path = checkpoint_path  # backward-compatible alias for the original misspelled name
tokenizer_path = "model/"

# LOAD MODEL & TOKENIZER
model = SentiNetTransformer(
    model_path=backbone_config_path,
    fc_dropout=hp.transformer_fc_dropout,
).to(device)
# map_location remaps the stored tensors onto `device` (e.g. a GPU-saved checkpoint
# on a CPU-only host); weights_only=True restricts unpickling to tensors and plain
# containers so loading the checkpoint cannot execute arbitrary code.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# This app only performs inference, so disable training-only behaviour (dropout)
# once, right after the weights are loaded.
model.eval()
tokenizer = DebertaV2Tokenizer.from_pretrained(tokenizer_path)

# INFERENCE FUNCTION
@torch.no_grad()
def sentiment_classifier(model, tokenizer, text, thresh=0.5, max_length_trun=256, device=None):
    """Classify the sentiment of ``text`` with a binary (single-score) model.

    Args:
        model: ``torch.nn.Module`` called as ``model(inputs)`` with the
            tokenizer's output mapping. A sigmoid is applied to its output and
            the first element is taken as the positive-class probability.
        tokenizer: Callable tokenizer returning a mapping of tensors when
            called with ``return_tensors="pt"``.
        text: Text to classify.
        thresh: Probability at or above which the text is labelled positive.
        max_length_trun: Maximum token length; longer inputs are truncated.
        device: Device the input tensors are moved to. When ``None``, the
            device of the model's first parameter is used (falling back to
            CUDA if available, else CPU, for a model without parameters).

    Returns:
        ``(label, probability)``: label is ``"😀 Positive"`` or
        ``"😞 Negative"``; probability is a float rounded to 3 decimals.
    """
    if device is None:
        # Place inputs wherever the model actually lives; guessing from CUDA
        # availability alone mismatches a model deliberately kept on the CPU.
        first_param = next(iter(model.parameters()), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoded = tokenizer(
        text,
        return_tensors="pt",
        add_special_tokens=True,
        max_length=max_length_trun,
        truncation=True,
        padding=True,
    )
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    model.eval()  # make sure dropout etc. are disabled even if the caller left it in train mode
    logits = model(inputs)
    # Flatten before indexing so both (batch, 1) and (batch,) outputs work;
    # flat element 0 is the first text's score in either layout.
    prob = float(torch.sigmoid(logits).reshape(-1)[0].item())

    label = "😀 Positive" if prob >= thresh else "😞 Negative"
    return label, round(prob, 3)

# GRADIO DEMO
def generation_fn(text):
    """Gradio callback: classify ``text`` with the module-level model and tokenizer.

    Returns the ``(label, probability)`` tuple produced by ``sentiment_classifier``.
    """
    # Pass the device the model was moved to at load time, so the inputs are
    # always placed alongside the model instead of re-deriving a device per call.
    return sentiment_classifier(model, tokenizer, text, max_length_trun=256, device=device)
    
# Assemble the bilingual Gradio interface around the classifier callback.
demo = build_demo(
    generation_fn,
    english_title="# SentiNet: Transformer‑Based Sentiment Classifier",
    persian_title="# سنتی‌نت: تحلیل احساسات با ترنسفورمر",
    assets_dir="assets",
    app_title="SentiNet",
)

# Serve the app only when executed directly, not when this module is imported.
if __name__ == "__main__":
    demo.launch()