asmashayea committed
Commit 7ebac28 · 1 parent: 6d91ffe

Files changed (2):
  1. app.py +1 -1
  2. inference.py +30 -33
app.py CHANGED
@@ -12,7 +12,7 @@ demo = gr.Interface(
     fn=run_absa,
     inputs=[
         gr.Textbox(label="Arabic Review"),
-        gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Choose Model", value="mT5")
+        gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Choose Model", value="Araberta")
     ],
     outputs=gr.Textbox(label="Extracted Aspect-Sentiment-Opinion Triplets"),
     title="Arabic ABSA (Aspect-Based Sentiment Analysis)",
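
The `run_absa` callback itself is defined elsewhere in app.py and is untouched by this commit. A minimal sketch of how the dropdown selection reaches the backend, assuming `run_absa` simply forwards both inputs to `predict_absa` from inference.py (this diff does not show its definition):

```python
# Hypothetical sketch; the actual run_absa in app.py is not shown in this diff.
from inference import predict_absa

def run_absa(review: str, model_name: str) -> str:
    # model_name carries the dropdown selection, which now defaults to "Araberta"
    return str(predict_absa(review, model_name))
```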
inference.py CHANGED
@@ -8,7 +8,6 @@ from huggingface_hub import hf_hub_download

 # Define supported models and their adapter IDs
 MODEL_OPTIONS = {
-
     "Araberta": {
         "base": "asmashayea/absa-araberta",
         "adapter": "asmashayea/absa-araberta"
@@ -22,31 +21,39 @@ MODEL_OPTIONS = {
         "adapter": "asmashayea/mbart-absa"
     },
     "GPT3.5": {
-        "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
+        "base": "bigscience/bloom-560m",  # placeholder
         "adapter": "asmashayea/gpt-absa"
     },
     "GPT4o": {
-        "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
+        "base": "bigscience/bloom-560m",  # placeholder
         "adapter": "asmashayea/gpt-absa"
     }
 }

 cached_models = {}

+
 def load_araberta():
     path = "asmashayea/absa-arabert"
+    device = "cuda" if torch.cuda.is_available() else "cpu"

     tokenizer = AutoTokenizer.from_pretrained(path)
     base_model = AutoModel.from_pretrained(path)
+
+    # Load LoRA adapter
     lora_config = LoraConfig.from_pretrained(path)
     lora_model = get_peft_model(base_model, lora_config)
-    local_pt = hf_hub_download(repo_id="asmashayea/absa-arabert", filename="bilstm_crf_head.pt")

+    # Download CRF head from Hub
+    local_pt = hf_hub_download(repo_id=path, filename="bilstm_crf_head.pt")

     config = AutoConfig.from_pretrained(path)
     model = BERT_BiLSTM_CRF(lora_model, config)
-    model.load_state_dict(torch.load(local_pt))
-    model.eval()
+
+    # Always map to current device
+    state_dict = torch.load(local_pt, map_location=torch.device(device))
+    model.load_state_dict(state_dict)
+    model.to(device).eval()

     cached_models["Araberta"] = (tokenizer, model)
     return tokenizer, model
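
`BERT_BiLSTM_CRF` is imported from elsewhere in this repo, so its definition is not part of the diff. Purely as an illustration of the architecture the name implies (encoder, then BiLSTM, then per-token emissions decoded by a CRF), a minimal sketch follows; the class internals, hidden size, and the pytorch-crf dependency are all assumptions, not the repo's actual code:

```python
import torch.nn as nn
from torchcrf import CRF  # pip install pytorch-crf (assumed dependency)

class BERT_BiLSTM_CRF(nn.Module):
    """Hypothetical reconstruction, not the repo's actual implementation."""
    def __init__(self, encoder, config, lstm_hidden=256):
        super().__init__()
        self.encoder = encoder  # the LoRA-wrapped BERT from load_araberta()
        self.lstm = nn.LSTM(config.hidden_size, lstm_hidden,
                            batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(2 * lstm_hidden, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids=input_ids,
                              attention_mask=attention_mask).last_hidden_state
        emissions = self.classifier(self.lstm(hidden)[0])
        # Viterbi-decode the best BIO tag sequence within the attention mask
        return self.crf.decode(emissions, mask=attention_mask.bool())
```

Whatever the real head looks like, loading its weights with `map_location` as the hunk above does is what lets a checkpoint saved on GPU be restored on a CPU-only Space.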
@@ -58,10 +65,15 @@ def infer_araberta(text):
     else:
         tokenizer, model = cached_models["Araberta"]

-
     device = next(model.parameters()).device

-    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
+    inputs = tokenizer(
+        text,
+        return_tensors='pt',
+        truncation=True,
+        padding='max_length',
+        max_length=128
+    )
     input_ids = inputs['input_ids'].to(device)
     attention_mask = inputs['attention_mask'].to(device)

@@ -75,15 +87,13 @@ def infer_araberta(text):
     clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
     clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in tokenizer.all_special_tokens]

-    # Horizontal output
+    # Horizontal token:label pairs
     pairs = [f"{token}: {label}" for token, label in zip(clean_tokens, clean_labels)]
     horizontal_output = " | ".join(pairs)

-    # Group by aspect span
+    # Group into aspect spans
     aspects = []
-    current_tokens = []
-    current_sentiment = None
-
+    current_tokens, current_sentiment = [], None
     for token, label in zip(clean_tokens, clean_labels):
         if label.startswith("B-"):
             if current_tokens:
@@ -101,9 +111,7 @@ def infer_araberta(text):
                 "aspect": " ".join(current_tokens).replace("##", ""),
                 "sentiment": current_sentiment
             })
-            current_tokens = []
-            current_sentiment = None
-
+            current_tokens, current_sentiment = [], None
     if current_tokens:
         aspects.append({
             "aspect": " ".join(current_tokens).replace("##", ""),
@@ -116,7 +124,6 @@ def infer_araberta(text):
     }


-
 def load_model(model_key):
     if model_key in cached_models:
         return cached_models[model_key]
@@ -124,34 +131,24 @@ def load_model(model_key):
     base_id = MODEL_OPTIONS[model_key]["base"]
     adapter_id = MODEL_OPTIONS[model_key]["adapter"]

+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
     tokenizer = AutoTokenizer.from_pretrained(adapter_id)
-    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_id)
-    model = PeftModel.from_pretrained(base_model, adapter_id)
+    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_id).to(device)
+    model = PeftModel.from_pretrained(base_model, adapter_id).to(device)
     model.eval()

     cached_models[model_key] = (tokenizer, model)
     return tokenizer, model


-
-
 def predict_absa(text, model_choice):
-
-
     if model_choice in ['mT5', 'mBART']:
         tokenizer, model = load_model(model_choice)
         decoded = infer_t5_prompt(text, tokenizer, model)
-
     elif model_choice == 'Araberta':
-
         decoded = infer_araberta(text)
+    else:
+        decoded = {"error": f"Model {model_choice} not supported"}

-    # prompt = f"استخرج الجوانب والآراء والمشاعر من النص التالي:\n{text}"
-    # inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
-
-    # with torch.no_grad():
-    #     outputs = model.generate(**inputs, max_new_tokens=128)
-
-    # decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return decoded
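
Before this commit, choosing GPT3.5 or GPT4o left `decoded` unassigned, so `predict_absa` raised `UnboundLocalError` at the final `return`; the new `else` branch returns an error dict instead. A quick smoke test, assuming the module imports cleanly (the Arabic review is an illustrative input meaning roughly "the food is delicious but the service is slow"):

```python
from inference import predict_absa

review = "الطعام لذيذ لكن الخدمة بطيئة"  # illustrative example input

print(predict_absa(review, "mT5"))       # seq2seq route via load_model()
print(predict_absa(review, "Araberta"))  # token-level BiLSTM-CRF route
print(predict_absa(review, "GPT3.5"))    # now returns the fallback error dict
```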
 