asmashayea commited on
Commit
6d91ffe
·
1 Parent(s): 4308dad

Add application file

Browse files
app.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from inference import predict_absa, MODEL_OPTIONS
3
+
4
+ def run_absa(review, model_choice):
5
+ try:
6
+ return predict_absa(review, model_choice)
7
+ except Exception as e:
8
+ return f"❌ Error: {str(e)}"
9
+
10
+
11
+ demo = gr.Interface(
12
+ fn=run_absa,
13
+ inputs=[
14
+ gr.Textbox(label="Arabic Review"),
15
+ gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Choose Model", value="mT5")
16
+ ],
17
+ outputs=gr.Textbox(label="Extracted Aspect-Sentiment-Opinion Triplets"),
18
+ title="Arabic ABSA (Aspect-Based Sentiment Analysis)",
19
+ description="Choose a model (Araberta, mT5, GPT) to extract aspects, opinions, and sentiment using LoRA adapters"
20
+ )
21
+
22
+ if __name__ == "__main__":
23
+ demo.launch()
araberta_setting/modeling_bilstm_crf.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchcrf import CRF
4
+
5
+ class BERT_BiLSTM_CRF(nn.Module):
6
+ def __init__(self, base_model, config, dropout_rate=0.2, rnn_dim=256):
7
+ super().__init__()
8
+ self.bert = base_model
9
+ self.label2id = config.label2id # <-- pulled from config
10
+ self.id2label = config.id2label
11
+ self.num_labels = config.num_labels
12
+
13
+ self.bilstm = nn.LSTM(
14
+ self.bert.config.hidden_size,
15
+ rnn_dim,
16
+ num_layers=2,
17
+ batch_first=True,
18
+ bidirectional=True,
19
+ dropout=0.2
20
+ )
21
+ self.dropout = nn.Dropout(dropout_rate)
22
+ self.classifier = nn.Linear(rnn_dim * 2, self.num_labels)
23
+ self.crf = CRF(self.num_labels, batch_first=True)
24
+
25
+ def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
26
+ outputs = self.bert(
27
+ input_ids=input_ids,
28
+ attention_mask=attention_mask,
29
+ token_type_ids=token_type_ids
30
+ )
31
+ lstm_out, _ = self.bilstm(self.dropout(outputs.last_hidden_state))
32
+ emissions = self.classifier(lstm_out)
33
+ mask = attention_mask.bool()
34
+
35
+ if labels is not None:
36
+ safe_labels = labels.clone()
37
+ safe_labels[labels == -100] = self.label2id['O']
38
+ loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
39
+ return {'loss': loss, 'logits': emissions}
40
+ else:
41
+ decoded = self.crf.decode(emissions, mask=mask)
42
+ max_len = input_ids.shape[1]
43
+ padded_decoded = [seq + [0] * (max_len - len(seq)) for seq in decoded]
44
+ logits = torch.tensor(padded_decoded, device=input_ids.device)
45
+ return {'logits': logits}
inference.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, AutoConfig
4
+ from peft import LoraConfig, get_peft_model, PeftModel
5
+ from araberta_setting.modeling_bilstm_crf import BERT_BiLSTM_CRF
6
+ from seq2seq_inference import infer_t5_prompt
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ # Define supported models and their adapter IDs
10
+ MODEL_OPTIONS = {
11
+
12
+ "Araberta": {
13
+ "base": "asmashayea/absa-araberta",
14
+ "adapter": "asmashayea/absa-araberta"
15
+ },
16
+ "mT5": {
17
+ "base": "google/mt5-base",
18
+ "adapter": "asmashayea/mt4-absa"
19
+ },
20
+ "mBART": {
21
+ "base": "facebook/mbart-large-50-many-to-many-mmt",
22
+ "adapter": "asmashayea/mbart-absa"
23
+ },
24
+ "GPT3.5": {
25
+ "base": "bigscience/bloom-560m", # example, not ideal for ABSA
26
+ "adapter": "asmashayea/gpt-absa"
27
+ },
28
+ "GPT4o": {
29
+ "base": "bigscience/bloom-560m", # example, not ideal for ABSA
30
+ "adapter": "asmashayea/gpt-absa"
31
+ }
32
+ }
33
+
34
+ cached_models = {}
35
+
36
+ def load_araberta():
37
+ path = "asmashayea/absa-arabert"
38
+
39
+ tokenizer = AutoTokenizer.from_pretrained(path)
40
+ base_model = AutoModel.from_pretrained(path)
41
+ lora_config = LoraConfig.from_pretrained(path)
42
+ lora_model = get_peft_model(base_model, lora_config)
43
+ local_pt = hf_hub_download(repo_id="asmashayea/absa-arabert", filename="bilstm_crf_head.pt")
44
+
45
+
46
+ config = AutoConfig.from_pretrained(path)
47
+ model = BERT_BiLSTM_CRF(lora_model, config)
48
+ model.load_state_dict(torch.load(local_pt))
49
+ model.eval()
50
+
51
+ cached_models["Araberta"] = (tokenizer, model)
52
+ return tokenizer, model
53
+
54
+
55
+ def infer_araberta(text):
56
+ if "Araberta" not in cached_models:
57
+ tokenizer, model = load_araberta()
58
+ else:
59
+ tokenizer, model = cached_models["Araberta"]
60
+
61
+
62
+ device = next(model.parameters()).device
63
+
64
+ inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
65
+ input_ids = inputs['input_ids'].to(device)
66
+ attention_mask = inputs['attention_mask'].to(device)
67
+
68
+ with torch.no_grad():
69
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
70
+ predicted_ids = outputs['logits'][0].cpu().tolist()
71
+
72
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
73
+ predicted_labels = [model.config.id2label.get(p, 'O') for p in predicted_ids]
74
+
75
+ clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
76
+ clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in tokenizer.all_special_tokens]
77
+
78
+ # Horizontal output
79
+ pairs = [f"{token}: {label}" for token, label in zip(clean_tokens, clean_labels)]
80
+ horizontal_output = " | ".join(pairs)
81
+
82
+ # Group by aspect span
83
+ aspects = []
84
+ current_tokens = []
85
+ current_sentiment = None
86
+
87
+ for token, label in zip(clean_tokens, clean_labels):
88
+ if label.startswith("B-"):
89
+ if current_tokens:
90
+ aspects.append({
91
+ "aspect": " ".join(current_tokens).replace("##", ""),
92
+ "sentiment": current_sentiment
93
+ })
94
+ current_tokens = [token]
95
+ current_sentiment = label.split("-")[1]
96
+ elif label.startswith("I-") and current_sentiment == label.split("-")[1]:
97
+ current_tokens.append(token)
98
+ else:
99
+ if current_tokens:
100
+ aspects.append({
101
+ "aspect": " ".join(current_tokens).replace("##", ""),
102
+ "sentiment": current_sentiment
103
+ })
104
+ current_tokens = []
105
+ current_sentiment = None
106
+
107
+ if current_tokens:
108
+ aspects.append({
109
+ "aspect": " ".join(current_tokens).replace("##", ""),
110
+ "sentiment": current_sentiment
111
+ })
112
+
113
+ return {
114
+ "token_predictions": horizontal_output,
115
+ "aspects": aspects
116
+ }
117
+
118
+
119
+
120
+ def load_model(model_key):
121
+ if model_key in cached_models:
122
+ return cached_models[model_key]
123
+
124
+ base_id = MODEL_OPTIONS[model_key]["base"]
125
+ adapter_id = MODEL_OPTIONS[model_key]["adapter"]
126
+
127
+ tokenizer = AutoTokenizer.from_pretrained(adapter_id)
128
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(base_id)
129
+ model = PeftModel.from_pretrained(base_model, adapter_id)
130
+ model.eval()
131
+
132
+ cached_models[model_key] = (tokenizer, model)
133
+ return tokenizer, model
134
+
135
+
136
+
137
+
138
+ def predict_absa(text, model_choice):
139
+
140
+
141
+ if model_choice in ['mT5', 'mBART']:
142
+ tokenizer, model = load_model(model_choice)
143
+ decoded = infer_t5_prompt(text, tokenizer, model)
144
+
145
+ elif model_choice == 'Araberta':
146
+
147
+ decoded = infer_araberta(text)
148
+
149
+
150
+ # prompt = f"استخرج الجوانب والآراء والمشاعر من النص التالي:\n{text}"
151
+ # inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
152
+
153
+ # with torch.no_grad():
154
+ # outputs = model.generate(**inputs, max_new_tokens=128)
155
+
156
+ # decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
157
+ return decoded
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ gradio
3
+ peft
4
+ torchcrf
5
+ torch
6
+ torchvision
7
+ torchaudio
8
+ pytorch-crf
9
+ sentencepiece
seq2seq_inference.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+ from peft import PeftModel
5
+
6
+ SYSTEM_PROMPT = (
7
+ "You are an advanced AI model specialized in extracting aspects and determining their sentiment polarity from customer reviews.\n\n"
8
+ "Instructions:\n"
9
+ "1. Extract only the aspects (nouns) mentioned in the review.\n"
10
+ "2. Assign a sentiment to each aspect: \"positive\", \"negative\", or \"neutral\".\n"
11
+ "3. Return aspects in the same language as they appear.\n"
12
+ "4. An aspect must be a noun that refers to a specific item or service the user described.\n"
13
+ "5. Ignore adjectives, general ideas, and vague topics.\n"
14
+ "6. Do NOT translate, explain, or add extra text.\n"
15
+ "7. The output must be just a valid JSON list with 'aspect' and 'sentiment'. Start with `[` and stop at `]`.\n"
16
+ "8. Do NOT output the instructions, review, or any text — only one output JSON list.\n"
17
+ "9. Just one output and one review."
18
+ )
19
+
20
+
21
+
22
+ def infer_t5_prompt(review_text, tokenizer, peft_model):
23
+ prompt = SYSTEM_PROMPT + f"\n\nReview: {review_text}"
24
+
25
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(peft_model.device)
26
+
27
+ with torch.no_grad():
28
+ outputs = peft_model.generate(
29
+ **inputs,
30
+ max_new_tokens=256,
31
+ num_beams=4,
32
+ do_sample=False,
33
+ temperature=0.0,
34
+ early_stopping=True,
35
+ pad_token_id=tokenizer.pad_token_id,
36
+ eos_token_id=tokenizer.eos_token_id,
37
+ )
38
+
39
+ decoded = tokenizer.decode(
40
+ outputs[0],
41
+ skip_special_tokens=True,
42
+ clean_up_tokenization_spaces=False
43
+ ).strip()
44
+
45
+ decoded = decoded.replace('<extra_id_0>', '').replace('</s>', '').strip()
46
+
47
+ try:
48
+ return json.loads(decoded)
49
+ except json.JSONDecodeError:
50
+ return decoded