amiguel committed on
Commit
104349c
·
verified ·
1 Parent(s): a1022aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -26
app.py CHANGED
@@ -35,14 +35,28 @@ HF_TOKEN = os.getenv("HF_TOKEN")
35
  # Model name
36
  MODEL_NAME = "amiguel/class_insp_program"
37
 
38
- # Label mapping
 
 
 
39
  LABEL_TO_CLASS = {
40
- 0: "Campaign", 1: "Corrosion Monitoring", 2: "Flare Tip", 3: "Flare TIP",
41
- 4: "FU Items", 5: "Intelligent Pigging", 6: "Lifting", 7: "Non Structural Tank",
42
- 8: "Piping", 9: "Pressure Safety Device", 10: "Pressure Vessel (VIE)",
43
- 11: "Pressure Vessel (VII)", 12: "Structure", 13: "Flame Arrestor"
 
 
 
 
 
 
 
 
 
44
  }
45
 
 
 
46
  # Required columns - UPDATED
47
  REQUIRED_COLS = ["MaintItem text", "Functional Loc.", "Description"]
48
 
@@ -61,6 +75,11 @@ with st.sidebar:
61
  type=["xlsx", "csv"],
62
  label_visibility="collapsed"
63
  )
 
 
 
 
 
64
 
65
  # Initialize session state
66
  if "messages" not in st.session_state:
@@ -105,7 +124,7 @@ def process_file(uploaded_file, _cache_key):
105
  return None
106
 
107
 
108
- # Model loading function
109
  @st.cache_resource
110
  def load_model(hf_token):
111
  if not TRANSFORMERS_AVAILABLE:
@@ -114,41 +133,108 @@ def load_model(hf_token):
114
  if not hf_token:
115
  st.error("πŸ” Please set the HF_TOKEN environment variable.")
116
  return None
 
117
  login(token=hf_token)
 
 
118
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
119
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(LABEL_TO_CLASS), token=hf_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  device = "cuda" if torch.cuda.is_available() else "cpu"
121
  model.to(device)
 
 
 
 
 
122
  return model, tokenizer
 
123
  except Exception as e:
124
  st.error(f"πŸ€– Model loading failed: {str(e)}")
 
 
125
  return None
126
 
127
 
128
- # Classification function
129
- def classify_instruction(prompt, context, model, tokenizer):
130
  model.eval()
131
  device = model.device
132
 
133
  if isinstance(context, pd.DataFrame):
134
  predictions = []
135
- for text in context["input_text"]:
136
- full_prompt = f"Context:\n{text}\n\nInstruction: {prompt}"
137
- inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
139
  with torch.no_grad():
140
  outputs = model(**inputs)
141
- prediction = outputs.logits.argmax().item()
142
- predictions.append(LABEL_TO_CLASS[prediction])
 
 
 
 
 
 
 
 
 
 
 
 
143
  return predictions
144
  else:
145
- full_prompt = f"Context:\n{context}\n\nInstruction: {prompt}"
146
- inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
 
147
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
148
  with torch.no_grad():
149
  outputs = model(**inputs)
 
150
  prediction = outputs.logits.argmax().item()
151
- return LABEL_TO_CLASS[prediction]
 
 
 
 
 
 
152
 
153
 
154
  # Excel download function - inserts Item Class before MaintItem text
@@ -162,14 +248,18 @@ def get_excel_download_link(df, filename="predicted_classes.xlsx"):
162
  cols = list(output_df.columns)
163
  if "Item Class" in cols:
164
  cols.remove("Item Class")
 
 
165
 
166
  # Find MaintItem text position
167
  if "MaintItem text" in cols:
168
  maint_idx = cols.index("MaintItem text")
169
- # Insert Item Class before MaintItem text
 
170
  cols.insert(maint_idx, "Item Class")
171
  else:
172
  # Fallback: put at beginning
 
173
  cols.insert(0, "Item Class")
174
 
175
  # Remove input_text column if present (internal use only)
@@ -198,9 +288,12 @@ def get_csv_download_link(df, filename="predicted_classes.csv"):
198
  cols = list(output_df.columns)
199
  if "Item Class" in cols and "MaintItem text" in cols:
200
  cols.remove("Item Class")
 
 
201
  maint_idx = cols.index("MaintItem text")
 
202
  cols.insert(maint_idx, "Item Class")
203
- output_df = output_df[cols]
204
 
205
  csv = output_df.to_csv(index=False)
206
  b64 = base64.b64encode(csv.encode()).decode()
@@ -261,15 +354,18 @@ if prompt := st.chat_input("Ask your inspection question..."):
261
  file_data = st.session_state.file_data
262
  if file_data["type"] == "table":
263
  with st.spinner("Classifying..."):
264
- predictions = classify_instruction(prompt, file_data["content"], model, tokenizer)
 
 
265
 
266
  # Add predictions to dataframe
267
  result_df = file_data["content"].copy()
268
  result_df["Item Class"] = predictions
 
269
 
270
  # Display preview (first 10 rows)
271
  st.write("**Predicted Item Classes (preview):**")
272
- display_cols = ["Item Class"] + REQUIRED_COLS
273
  st.dataframe(result_df[display_cols].head(10), use_container_width=True)
274
 
275
  # Stats
@@ -277,6 +373,10 @@ if prompt := st.chat_input("Ask your inspection question..."):
277
  st.write("**Class distribution:**")
278
  st.write(result_df["Item Class"].value_counts())
279
 
 
 
 
 
280
  # Download links
281
  st.markdown("---")
282
  col1, col2 = st.columns(2)
@@ -287,15 +387,21 @@ if prompt := st.chat_input("Ask your inspection question..."):
287
 
288
  response = f"βœ… Classification completed for {len(predictions)} rows."
289
  else:
290
- predicted_class = classify_instruction(prompt, file_data["content"], model, tokenizer)
291
- response = f"The Item Class is: **{predicted_class}**"
 
 
292
  else:
293
- predicted_class = classify_instruction(prompt, "", model, tokenizer)
294
- response = f"The Item Class is: **{predicted_class}**"
 
 
295
 
296
  st.markdown(response)
297
  st.session_state.messages.append({"role": "assistant", "content": response})
298
  except Exception as e:
299
  st.error(f"⚑ Classification error: {str(e)}")
 
 
300
  else:
301
- st.error("πŸ€– Model not loaded!")
 
35
  # Model name
36
  MODEL_NAME = "amiguel/class_insp_program"
37
 
38
+ # =============================================================================
39
+ # FIXED: Label mapping must match EXACTLY what the model was trained with
40
+ # The model was trained with 13 classes (Flare Tip and Flare TIP were merged)
41
+ # =============================================================================
42
  LABEL_TO_CLASS = {
43
+ 0: "Campaign",
44
+ 1: "Corrosion Monitoring",
45
+ 2: "Flare Tip", # This now covers both "Flare Tip" and "Flare TIP"
46
+ 3: "FU Items",
47
+ 4: "Intelligent Pigging",
48
+ 5: "Lifting",
49
+ 6: "Non Structural Tank",
50
+ 7: "Piping",
51
+ 8: "Pressure Safety Device",
52
+ 9: "Pressure Vessel (VIE)",
53
+ 10: "Pressure Vessel (VII)",
54
+ 11: "Structure",
55
+ 12: "Flame Arrestor"
56
  }
57
 
58
+ NUM_LABELS = len(LABEL_TO_CLASS) # Should be 13
59
+
60
  # Required columns - UPDATED
61
  REQUIRED_COLS = ["MaintItem text", "Functional Loc.", "Description"]
62
 
 
75
  type=["xlsx", "csv"],
76
  label_visibility="collapsed"
77
  )
78
+
79
+ # Show model info
80
+ st.markdown("---")
81
+ st.markdown(f"**Model:** `{MODEL_NAME}`")
82
+ st.markdown(f"**Classes:** {NUM_LABELS}")
83
 
84
  # Initialize session state
85
  if "messages" not in st.session_state:
 
124
  return None
125
 
126
 
127
+ # Model loading function - FIXED
128
  @st.cache_resource
129
  def load_model(hf_token):
130
  if not TRANSFORMERS_AVAILABLE:
 
133
  if not hf_token:
134
  st.error("πŸ” Please set the HF_TOKEN environment variable.")
135
  return None
136
+
137
  login(token=hf_token)
138
+
139
+ # Load tokenizer
140
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
141
+
142
+ # =================================================================
143
+ # FIXED: Load model WITHOUT specifying num_labels
144
+ # Let it auto-detect from config.json, or use ignore_mismatched_sizes
145
+ # =================================================================
146
+ try:
147
+ # First try: Load without specifying num_labels (uses config.json)
148
+ model = AutoModelForSequenceClassification.from_pretrained(
149
+ MODEL_NAME,
150
+ token=hf_token
151
+ )
152
+ except Exception as e1:
153
+ # Fallback: Try with explicit num_labels and ignore size mismatch
154
+ st.warning(f"Auto-load failed, trying with explicit config: {str(e1)}")
155
+ model = AutoModelForSequenceClassification.from_pretrained(
156
+ MODEL_NAME,
157
+ num_labels=NUM_LABELS,
158
+ token=hf_token,
159
+ ignore_mismatched_sizes=True # This allows loading even if sizes differ
160
+ )
161
+
162
  device = "cuda" if torch.cuda.is_available() else "cpu"
163
  model.to(device)
164
+ model.eval()
165
+
166
+ # Log successful load
167
+ st.sidebar.success(f"βœ… Model loaded on {device}")
168
+
169
  return model, tokenizer
170
+
171
  except Exception as e:
172
  st.error(f"πŸ€– Model loading failed: {str(e)}")
173
+ import traceback
174
+ st.error(f"Full traceback:\n```\n{traceback.format_exc()}\n```")
175
  return None
176
 
177
 
178
# Classification function — returns predicted class names, optionally with
# softmax confidence scores.
def classify_instruction(prompt, context, model, tokenizer, return_confidence=False):
    """Classify either a DataFrame of texts or a single text.

    If *context* is a DataFrame, its "input_text" column is classified in
    batches and a list of class names is returned (plus a parallel list of
    confidences when return_confidence is True). Otherwise the single
    *context* string (falling back to *prompt* when context is falsy) is
    classified and one class name (and confidence) is returned.
    """
    model.eval()
    device = model.device

    if isinstance(context, pd.DataFrame):
        labels, scores = [], []

        # Batch inference for efficiency.
        BATCH = 32
        all_texts = context["input_text"].tolist()

        for start in range(0, len(all_texts), BATCH):
            chunk = all_texts[start:start + BATCH]

            encoded = tokenizer(
                chunk,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=128,
            )
            encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

            with torch.no_grad():
                outputs = model(**encoded)
            probabilities = torch.softmax(outputs.logits, dim=-1)
            pred_ids = outputs.logits.argmax(dim=-1).cpu().numpy()
            pred_confs = probabilities.max(dim=-1).values.cpu().numpy()

            for label_id, score in zip(pred_ids, pred_confs):
                # Guard against a prediction id outside our mapping.
                labels.append(LABEL_TO_CLASS.get(label_id, f"Unknown ({label_id})"))
                scores.append(float(score))

        if return_confidence:
            return labels, scores
        return labels

    # Single-text classification.
    text = str(context) if context else prompt
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)
    probabilities = torch.softmax(outputs.logits, dim=-1)
    label_id = outputs.logits.argmax().item()
    score = probabilities[0, label_id].item()

    result = LABEL_TO_CLASS.get(label_id, f"Unknown ({label_id})")

    if return_confidence:
        return result, score
    return result
238
 
239
 
240
  # Excel download function - inserts Item Class before MaintItem text
 
248
  cols = list(output_df.columns)
249
  if "Item Class" in cols:
250
  cols.remove("Item Class")
251
+ if "Confidence" in cols:
252
+ cols.remove("Confidence")
253
 
254
  # Find MaintItem text position
255
  if "MaintItem text" in cols:
256
  maint_idx = cols.index("MaintItem text")
257
+ # Insert Item Class and Confidence before MaintItem text
258
+ cols.insert(maint_idx, "Confidence")
259
  cols.insert(maint_idx, "Item Class")
260
  else:
261
  # Fallback: put at beginning
262
+ cols.insert(0, "Confidence")
263
  cols.insert(0, "Item Class")
264
 
265
  # Remove input_text column if present (internal use only)
 
288
  cols = list(output_df.columns)
289
  if "Item Class" in cols and "MaintItem text" in cols:
290
  cols.remove("Item Class")
291
+ if "Confidence" in cols:
292
+ cols.remove("Confidence")
293
  maint_idx = cols.index("MaintItem text")
294
+ cols.insert(maint_idx, "Confidence")
295
  cols.insert(maint_idx, "Item Class")
296
+ output_df = output_df[[c for c in cols if c in output_df.columns]]
297
 
298
  csv = output_df.to_csv(index=False)
299
  b64 = base64.b64encode(csv.encode()).decode()
 
354
  file_data = st.session_state.file_data
355
  if file_data["type"] == "table":
356
  with st.spinner("Classifying..."):
357
+ predictions, confidences = classify_instruction(
358
+ prompt, file_data["content"], model, tokenizer, return_confidence=True
359
+ )
360
 
361
  # Add predictions to dataframe
362
  result_df = file_data["content"].copy()
363
  result_df["Item Class"] = predictions
364
+ result_df["Confidence"] = [f"{c:.2%}" for c in confidences]
365
 
366
  # Display preview (first 10 rows)
367
  st.write("**Predicted Item Classes (preview):**")
368
+ display_cols = ["Item Class", "Confidence"] + REQUIRED_COLS
369
  st.dataframe(result_df[display_cols].head(10), use_container_width=True)
370
 
371
  # Stats
 
373
  st.write("**Class distribution:**")
374
  st.write(result_df["Item Class"].value_counts())
375
 
376
+ # Average confidence
377
+ avg_conf = sum(confidences) / len(confidences)
378
+ st.write(f"**Average confidence:** {avg_conf:.2%}")
379
+
380
  # Download links
381
  st.markdown("---")
382
  col1, col2 = st.columns(2)
 
387
 
388
  response = f"βœ… Classification completed for {len(predictions)} rows."
389
  else:
390
+ predicted_class, confidence = classify_instruction(
391
+ prompt, file_data["content"], model, tokenizer, return_confidence=True
392
+ )
393
+ response = f"The Item Class is: **{predicted_class}** (confidence: {confidence:.2%})"
394
  else:
395
+ predicted_class, confidence = classify_instruction(
396
+ prompt, "", model, tokenizer, return_confidence=True
397
+ )
398
+ response = f"The Item Class is: **{predicted_class}** (confidence: {confidence:.2%})"
399
 
400
  st.markdown(response)
401
  st.session_state.messages.append({"role": "assistant", "content": response})
402
  except Exception as e:
403
  st.error(f"⚑ Classification error: {str(e)}")
404
+ import traceback
405
+ st.error(f"```\n{traceback.format_exc()}\n```")
406
  else:
407
+ st.error("πŸ€– Model not loaded!")