yassine123Z committed on
Commit
f820f5a
Β·
verified Β·
1 Parent(s): 54d63e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +454 -78
app.py CHANGED
@@ -16,133 +16,509 @@ from datetime import datetime
16
  app = FastAPI(title="Transaction Category Mapper")
17
 
18
  # ==================================================
19
- # 🧠 Load Model
20
  # ==================================================
21
- model_path = "HEN10/setfit-particular-transaction-solon-embeddings-labels-large-kaggle-automatisation-v1"
22
- model = SetFitModel.from_pretrained(model_path)
23
- device = "cuda" if torch.cuda.is_available() else "cpu"
24
- model.to(device)
25
 
26
  # ==================================================
27
- # 🧩 Utility: Classify transaction text
28
  # ==================================================
29
- def classify_transaction(text, top_k=3):
30
- embeddings = model.model_body.encode([text], convert_to_tensor=True, device=device)
31
- label_embeddings = model.model_body.encode(model.labels, convert_to_tensor=True, device=device)
32
- cos_scores = util.pytorch_cos_sim(embeddings, label_embeddings)[0]
33
- top_results = torch.topk(cos_scores, k=top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  results = []
36
- for score, idx in zip(top_results.values, top_results.indices):
37
- cat = model.labels[idx]
38
- cat_split = cat.split(" > ")
39
- cat1 = cat_split[0] if len(cat_split) > 0 else ""
40
- cat2 = cat_split[1] if len(cat_split) > 1 else ""
41
- results.append({"cat1": cat1, "cat2": cat2, "score": float(score)})
 
42
  return results
43
 
 
 
 
 
 
 
 
 
 
 
 
44
  # ==================================================
45
- # πŸ“₯ Batch Mapping for Review (Simplified Columns)
46
  # ==================================================
47
  def map_csv_for_review(file):
48
- """Process CSV and return simplified results with editable prediction columns."""
49
  global batch_results
50
-
51
  try:
52
  if file is None:
53
  return None, "⚠️ Please upload a CSV file"
54
-
55
  df = pd.read_csv(file.name)
56
  if "transaction" not in df.columns:
57
- return None, "❌ Missing column 'transaction'"
58
-
59
  results = []
60
  for idx, text in enumerate(df["transaction"]):
61
  if pd.isna(text) or str(text).strip() == "":
62
  continue
63
-
64
- top_match = classify_transaction(str(text), top_k=1)[0]
 
 
 
65
  results.append({
66
  "ID": idx + 1,
67
- "Transaction": str(text)[:100],
68
- "Predicted_Cat1": top_match["cat1"],
69
- "Predicted_Cat2": top_match["cat2"],
70
- "Confidence": round(top_match["score"], 3),
71
- "Status": "βœ… OK" if top_match["score"] > 0.8 else "⚠️ Review"
72
  })
73
 
74
- batch_results = pd.DataFrame(results)
75
- return batch_results, f"βœ… Processed {len(batch_results)} transactions!"
76
 
 
 
 
 
 
77
  except Exception as e:
78
  return None, f"❌ Error: {str(e)}"
79
 
80
  # ==================================================
81
- # πŸ’Ύ Save Corrections & Generate Training Data
82
  # ==================================================
83
- def save_batch_corrections(updated_df):
84
- """Save corrected table and prepare new training data."""
85
- global batch_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- corrected = pd.DataFrame(updated_df)
88
- corrected.to_csv("corrected_results.csv", index=False)
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- # Build training dataset (only corrected entries)
91
- training_data = corrected[["Transaction", "Predicted_Cat1", "Predicted_Cat2"]].copy()
92
- training_data = training_data.rename(columns={
93
- "Transaction": "text",
94
- "Predicted_Cat1": "label1",
95
- "Predicted_Cat2": "label2"
96
- })
97
- training_data.to_csv("training_data.csv", index=False)
98
 
99
- return (
100
- f"πŸ’Ύ Saved {len(corrected)} reviewed transactions.",
101
- "corrected_results.csv",
102
- "training_data.csv"
103
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  # ==================================================
106
- # 🌍 Gradio Interface
107
  # ==================================================
108
- with gr.Blocks(title="Transaction Category Mapper") as demo:
109
- gr.Markdown("# πŸ’Ό Transaction Category Mapper")
110
- gr.Markdown("Upload a file β†’ review predictions β†’ correct directly β†’ save for fine-tuning.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  with gr.Tab("πŸ“‚ Batch Review & Correct"):
113
- upload_file = gr.File(label="Upload your CSV file (must include 'transaction' column)")
114
- btn_process = gr.Button("πŸš€ Process Transactions", variant="primary")
115
- review_table = gr.Dataframe(
116
- headers=["ID", "Transaction", "Predicted_Cat1", "Predicted_Cat2", "Confidence", "Status"],
 
 
 
 
 
 
 
 
 
 
 
117
  interactive=True,
118
- wrap=True,
119
- max_rows=50
120
  )
121
- process_status = gr.Textbox(label="Status", lines=3, interactive=False)
122
-
 
 
 
 
 
 
 
 
123
  btn_process.click(
124
- fn=map_csv_for_review,
125
- inputs=upload_file,
126
  outputs=[review_table, process_status]
127
  )
128
-
129
- gr.Markdown("---")
130
- gr.Markdown("### πŸ’Ύ Save and Export Results")
131
- btn_save = gr.Button("πŸ’Ύ Save All Corrections", variant="primary")
132
- btn_download_corrected = gr.File(label="πŸ“₯ Download Corrected Results")
133
- btn_download_training = gr.File(label="πŸ“₯ Download Training Data (Corrections)")
134
-
135
- save_status = gr.Textbox(label="Save Status", lines=3, interactive=False)
136
-
137
  btn_save.click(
138
  fn=save_batch_corrections,
139
  inputs=review_table,
140
  outputs=[save_status, btn_download_corrected, btn_download_training]
141
  )
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  # ==================================================
144
- # πŸš€ Launch App
145
  # ==================================================
146
- if __name__ == "__main__":
147
- import uvicorn
148
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  app = FastAPI(title="Transaction Category Mapper")
17
 
18
  # ==================================================
19
+ # 🧠 Load Main Model
20
  # ==================================================
21
+ print("Loading main model...")
22
+ model = SetFitModel.from_pretrained("yassine123Z/EmissionFactor-mapper2-v2")
23
+ print("Model loaded successfully!")
 
24
 
25
  # ==================================================
26
+ # πŸ“˜ Reference Categories
27
  # ==================================================
28
+ ref_data = pd.DataFrame({
29
+ "Cat1EN": [
30
+ "Purchase of goods","Purchase of goods","Purchase of goods","Purchase of goods",
31
+ "Purchase of goods","Purchase of goods","Purchase of goods","Purchase of goods",
32
+ "Purchase of goods","Purchase of goods","Purchase of materials","Purchase of materials",
33
+ "Purchase of materials","Purchase of materials","Purchase of materials","Purchase of materials",
34
+ "Purchase of services","Purchase of services","Purchase of services","Purchase of services",
35
+ "Purchase of services","Purchase of services","Purchase of services","Purchase of services",
36
+ "Purchase of services","Purchase of services","Purchase of services","Purchase of services",
37
+ "Purchase of services","Purchase of services","Food & beverages","Food & beverages",
38
+ "Food & beverages","Food & beverages","Food & beverages","Food & beverages",
39
+ "Food & beverages","Food & beverages","Food & beverages","Food & beverages",
40
+ "Heating and air conditioning","Heating and air conditioning","Fuels","Fuels","Fuels","Fuels",
41
+ "Fuels","Fuels",
42
+ "Mobility (freight)","Mobility (freight)","Mobility (freight)","Mobility (freight)",
43
+ "Mobility (freight)",
44
+ "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
45
+ "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
46
+ "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
47
+ "Process and fugitive emissions","Process and fugitive emissions",
48
+ "Process and fugitive emissions",
49
+ "Waste treatment","Waste treatment","Waste treatment",
50
+ "Waste treatment","Waste treatment","Waste treatment","Waste treatment","Waste treatment",
51
+ "Waste treatment","Waste treatment","Waste treatment","Waste treatment",
52
+ "Use of electricity","Use of electricity","Use of electricity"
53
+ ],
54
+ "Cat2EN": [
55
+ "Sporting goods","Buildings","Office supplies","Water consumption",
56
+ "Household appliances","Electrical equipment","Machinery and equipment","Furniture",
57
+ "Textiles and clothing","Vehicles","Construction materials","Organic materials",
58
+ "Paper and cardboard","Plastics and rubber","Chemicals","Refrigerants and others",
59
+ "Equipment rental","Building rental","Furniture rental","Vehicle rental and maintenance",
60
+ "Information and cultural services","Catering services","Health services","Specialized craft services",
61
+ "Administrative / consulting services","Cleaning services","IT services","Logistics services",
62
+ "Marketing / advertising services","Technical services","Alcoholic beverages","Non-alcoholic beverages",
63
+ "Condiments","Desserts","Fruits and vegetables","Fats and oils","Prepared / cooked meals",
64
+ "Animal products","Cereal products","Dairy products","Heat and steam","Air conditioning and refrigeration",
65
+ "Fossil fuels","Mobile fossil fuels","Organic fuels","Gaseous fossil fuels","Liquid fossil fuels",
66
+ "Solid fossil fuels",
67
+ "Air transport","Ship transport","Truck transport","Combined transport",
68
+ "Train transport",
69
+ "Air transport","Coach / Urban bus","Ship transport","Combined transport",
70
+ "E-Bike","Accommodation / Events","Soft mobility","Motorcycle / Scooter","Train transport",
71
+ "Public transport","Car",
72
+ "Agriculture","Global warming potential","Industrial processes",
73
+ "Commercial and industrial","Wastewater","Electrical equipment","Households and similar",
74
+ "Metal","Organic materials","Paper and cardboard","Batteries and accumulators","Plastics",
75
+ "Fugitive process emissions","Textiles","Glass",
76
+ "Electricity for electric vehicles","Renewables","Standard"
77
+ ],
78
+ "DescriptionCat2EN": [
79
+ "Goods purchase - sports","Goods purchase - buildings","Goods purchase - office items","Goods purchase - water",
80
+ "Goods purchase - appliances","Goods purchase - electricals","Goods purchase - machinery","Goods purchase - furniture",
81
+ "Goods purchase - textiles","Goods purchase - vehicles","Material purchase - construction","Material purchase - organic",
82
+ "Material purchase - paper","Material purchase - plastics","Material purchase - chemicals","Material purchase - refrigerants",
83
+ "Service - equipment rental","Service - building rental","Service - furniture rental","Service - vehicles",
84
+ "Service - info/culture","Service - catering","Service - healthcare","Service - crafts",
85
+ "Service - admin/consulting","Service - cleaning","Service - IT","Service - logistics",
86
+ "Service - marketing","Service - technical","Beverages - alcoholic","Beverages - non-alcoholic",
87
+ "Food condiments","Food desserts","Food fruits & vegetables","Food fats & oils","Prepared meals",
88
+ "Animal-based food","Cereal-based food","Dairy products","Heating - heat & steam","Heating - cooling/refrigeration",
89
+ "Fuel - fossil","Fuel - mobile fossil","Fuel - organic","Fuel - gaseous","Fuel - liquid","Fuel - solid",
90
+ "Freight transport - air","Freight transport - ship","Freight transport - truck","Freight transport - combined",
91
+ "Freight transport - train",
92
+ "Passenger transport - air","Passenger transport - bus","Passenger transport - ship",
93
+ "Passenger transport - combined","Passenger transport - e-bike","Passenger transport - accommodation/events",
94
+ "Passenger transport - soft mobility","Passenger transport - scooter/motorbike","Passenger transport - train",
95
+ "Passenger transport - public","Passenger transport - car",
96
+ "Emissions - agriculture","Emissions - warming potential",
97
+ "Emissions - industry",
98
+ "Waste - commercial/industrial","Waste - wastewater","Waste - electricals",
99
+ "Waste - households","Waste - metals","Waste - organics","Waste - paper","Waste - batteries",
100
+ "Waste - plastics","Waste - fugitive","Waste - textiles","Waste - glass",
101
+ "Electricity - EVs","Electricity - renewables","Electricity - standard"
102
+ ]
103
+ })
104
+
105
+ ref_data["combined"] = ref_data[["Cat1EN", "Cat2EN", "DescriptionCat2EN"]].agg(" ".join, axis=1)
106
+ ref_embeddings = model.encode(ref_data["combined"].tolist())
107
+
108
+ # Get unique categories for dropdowns
109
+ unique_cat1 = sorted(ref_data["Cat1EN"].unique().tolist())
110
+ unique_cat2 = sorted(ref_data["Cat2EN"].unique().tolist())
111
+
112
+ # ==================================================
113
+ # πŸ’Ύ Storage for batch data and corrections
114
+ # ==================================================
115
+ batch_results = None # Store current batch for editing
116
+ corrections_data = []
117
+
118
+ # ==================================================
119
+ # πŸ” Core Classification Logic
120
+ # ==================================================
121
def classify_transaction(text: str, top_k=3):
    """Classify a transaction description against the reference categories.

    Args:
        text: Free-text transaction description.
        top_k: Maximum number of best-scoring categories to return.

    Returns:
        A list of up to ``top_k`` dicts with keys ``cat1``, ``cat2`` and
        ``score`` (cosine similarity as float), best match first.
        An empty list for blank/empty input.
    """
    if not text or text.strip() == "":
        return []

    # Embed the query once and score against the precomputed reference
    # embeddings (module-level `ref_embeddings`, built from `ref_data`).
    trans_emb = model.encode([text])[0]
    scores = util.pytorch_cos_sim(torch.tensor(trans_emb), torch.tensor(ref_embeddings)).flatten()

    # Single topk call — the original computed scores.topk(...) twice,
    # once for .indices and once for .values.
    top = scores.topk(min(top_k, len(scores)))

    results = []
    for idx, score in zip(top.indices.tolist(), top.values.tolist()):
        results.append({
            "cat1": ref_data.iloc[idx]["Cat1EN"],
            "cat2": ref_data.iloc[idx]["Cat2EN"],
            "score": float(score),
        })

    return results
141
 
142
def classify_single(text: str):
    """Classify one transaction; return (cat1, cat2, confidence score)."""
    if not text or not text.strip():
        return "Please enter a transaction", "", 0.0

    matches = classify_transaction(text, top_k=1)
    if not matches:
        return "No results", "", 0.0

    best = matches[0]
    return best["cat1"], best["cat2"], best["score"]
152
+
153
  # ==================================================
154
+ # πŸ“‚ Batch Mapping with Review Table
155
  # ==================================================
156
def map_csv_for_review(file):
    """Process an uploaded CSV and build an editable prediction table.

    Expects a 'transaction' column; stores the resulting table in the
    module-level `batch_results` for later correction tracking.
    Returns (DataFrame or None, status message).
    """
    global batch_results

    try:
        if file is None:
            return None, "⚠️ Please upload a CSV file"

        df = pd.read_csv(file.name)
        if "transaction" not in df.columns:
            return None, "❌ Error: Missing column 'transaction'"

        rows = []
        for idx, raw in enumerate(df["transaction"]):
            # Skip blank / missing cells.
            if pd.isna(raw) or str(raw).strip() == "":
                continue

            matches = classify_transaction(str(raw), top_k=1)
            if not matches:
                continue

            best = matches[0]
            rows.append({
                "ID": idx + 1,
                "Transaction": str(raw)[:80],  # Truncate long text
                "Cat1": best["cat1"],
                "Cat2": best["cat2"],
                "Confidence": round(best["score"], 3),
                "Status": "βœ… OK" if best["score"] > 0.8 else "⚠️ Review",
            })

        if not rows:
            return None, "❌ No valid transactions found"

        batch_results = pd.DataFrame(rows)
        status_msg = f"βœ… Processed {len(batch_results)} transactions. Review and edit the 'Cat1' and 'Cat2' columns directly, then click 'Save Corrections'."

        return batch_results, status_msg

    except Exception as e:
        return None, f"❌ Error: {str(e)}"
196
 
197
  # ==================================================
198
+ # πŸ’Ύ Save Corrections from Edited Table
199
  # ==================================================
200
def save_batch_corrections(edited_df):
    """Persist the reviewed table and extract user corrections.

    Compares the edited table against the original predictions stored in
    the module-level `batch_results`, appends any changed rows to the
    global `corrections_data` list, and writes two CSVs into a temp dir:
    the full corrected table, and (if any) a corrections-only training file.

    Returns (status message, corrected-CSV path, training-CSV path or None).
    """
    global corrections_data, batch_results

    try:
        if edited_df is None or len(edited_df) == 0:
            return "⚠️ No data to save", None, None

        # Gradio may hand back a list-of-lists instead of a DataFrame.
        if not isinstance(edited_df, pd.DataFrame):
            edited_df = pd.DataFrame(edited_df)

        # Index original predictions by row ID for comparison.
        original_predictions = {}
        if batch_results is not None:
            for _, row in batch_results.iterrows():
                original_predictions[row["ID"]] = {
                    "cat1": row["Cat1"],
                    "cat2": row["Cat2"],
                }

        # Collect rows where the edited categories differ from the prediction.
        corrections_count = 0
        new_corrections = []

        for _, row in edited_df.iterrows():
            transaction_id = row["ID"]
            original_cat1 = original_predictions.get(transaction_id, {}).get("cat1", "")
            original_cat2 = original_predictions.get(transaction_id, {}).get("cat2", "")
            edited_cat1 = row.get("Cat1", original_cat1)
            edited_cat2 = row.get("Cat2", original_cat2)

            if edited_cat1 != original_cat1 or edited_cat2 != original_cat2:
                corrections_count += 1
                new_corrections.append({
                    "timestamp": datetime.now().isoformat(),
                    "transaction": row["Transaction"],
                    "original_cat1": original_cat1,
                    "original_cat2": original_cat2,
                    "corrected_cat1": edited_cat1,
                    "corrected_cat2": edited_cat2,
                    "confidence": row.get("Confidence", 0),
                })

        corrections_data.extend(new_corrections)

        tmp_dir = tempfile.mkdtemp()

        # One shared timestamp for both filenames — the original called
        # datetime.now() separately per file, so the suffixes could differ
        # across a second boundary.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Full reviewed table.
        corrected_file = os.path.join(tmp_dir, f"corrected_results_{stamp}.csv")
        edited_df.to_csv(corrected_file, index=False)

        # Corrections-only file for retraining (omitted when nothing changed).
        training_file = None
        if new_corrections:
            training_file = os.path.join(tmp_dir, f"training_corrections_{stamp}.csv")
            pd.DataFrame(new_corrections).to_csv(training_file, index=False)

        status_msg = f"βœ… Saved!\nπŸ“Š Total rows: {len(edited_df)}\n✏️ Corrections made: {corrections_count}\nπŸ’Ύ Total corrections in memory: {len(corrections_data)}"

        return status_msg, corrected_file, training_file

    except Exception as e:
        return f"❌ Error saving: {str(e)}", None, None
268
 
269
+ # ==================================================
270
+ # πŸ“₯ Export All Corrections
271
+ # ==================================================
272
def export_all_corrections():
    """Write every accumulated correction to a timestamped CSV in a temp dir.

    Returns (file path or None, status message).
    """
    if not corrections_data:
        return None, "⚠️ No corrections to export yet"

    export_df = pd.DataFrame(corrections_data)
    out_dir = tempfile.mkdtemp()
    output_path = os.path.join(
        out_dir,
        f"all_corrections_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
    )
    export_df.to_csv(output_path, index=False)
    return output_path, f"βœ… Exported {len(corrections_data)} total corrections for model training"
282
 
283
def show_corrections():
    """Return the accumulated corrections as a DataFrame for display."""
    if corrections_data:
        return pd.DataFrame(corrections_data)
    # Placeholder table when nothing has been saved yet.
    return pd.DataFrame({"message": ["No corrections saved yet. Upload a CSV and make corrections!"]})
 
 
 
288
 
289
+ # ==================================================
290
+ # πŸ”§ Model Comparison
291
+ # ==================================================
292
def compare_models_fixed(hf_model_url, file):
    """Compare the local model against another HuggingFace SetFit model.

    Runs best-match classification with both models over up to 50 rows of
    the uploaded CSV ('transaction' column required) and reports agreement.

    Args:
        hf_model_url: HuggingFace model ID of the comparison model.
        file: Uploaded CSV file object (Gradio File).

    Returns:
        (summary string, results DataFrame) on success,
        (error string, None) on failure.
    """
    try:
        if not hf_model_url or not hf_model_url.strip():
            return "❌ Please provide a HuggingFace model ID", None

        if file is None:
            return "❌ Please upload a CSV file", None

        df = pd.read_csv(file.name)
        if "transaction" not in df.columns:
            return "❌ Missing column 'transaction' in CSV", None

        print(f"Loading comparison model: {hf_model_url}")
        try:
            hf_model = SetFitModel.from_pretrained(hf_model_url.strip())
        except Exception as e:
            return f"❌ Failed to load model '{hf_model_url}': {str(e)}", None

        print("Encoding reference data...")
        # Reuse the module-level reference embeddings for the local model —
        # the original re-encoded the identical `ref_data["combined"]` here.
        local_embs = ref_embeddings
        hf_embs = hf_model.encode(ref_data["combined"].tolist())

        results = []
        max_samples = min(50, len(df))  # cap work per comparison run

        print(f"Comparing {max_samples} samples...")
        for idx, text in enumerate(df["transaction"][:max_samples]):
            if pd.isna(text) or str(text).strip() == "":
                continue

            text = str(text)

            # Best reference match under the local model.
            trans_emb_local = model.encode([text])[0]
            scores_local = util.pytorch_cos_sim(
                torch.tensor(trans_emb_local),
                torch.tensor(local_embs)
            ).flatten()
            best_idx_local = scores_local.argmax().item()

            # Best reference match under the comparison model.
            trans_emb_hf = hf_model.encode([text])[0]
            scores_hf = util.pytorch_cos_sim(
                torch.tensor(trans_emb_hf),
                torch.tensor(hf_embs)
            ).flatten()
            best_idx_hf = scores_hf.argmax().item()

            local_cat1 = ref_data.iloc[best_idx_local]["Cat1EN"]
            local_cat2 = ref_data.iloc[best_idx_local]["Cat2EN"]
            hf_cat1 = ref_data.iloc[best_idx_hf]["Cat1EN"]
            hf_cat2 = ref_data.iloc[best_idx_hf]["Cat2EN"]

            match = "βœ…" if (local_cat1 == hf_cat1 and local_cat2 == hf_cat2) else "❌"

            results.append({
                "Transaction": text[:50] + "..." if len(text) > 50 else text,
                "Local_Cat1": local_cat1,
                "Local_Cat2": local_cat2,
                "Local_Conf": round(float(scores_local[best_idx_local]), 3),
                "HF_Cat1": hf_cat1,
                "HF_Cat2": hf_cat2,
                "HF_Conf": round(float(scores_hf[best_idx_hf]), 3),
                "Match": match
            })

        if not results:
            return "❌ No valid transactions to compare", None

        result_df = pd.DataFrame(results)
        matches = (result_df["Match"] == "βœ…").sum()
        match_rate = round(matches / len(result_df) * 100, 2)

        summary = f"""βœ… Comparison Complete!
πŸ“Š Analyzed: {len(result_df)} transactions
🎯 Match Rate: {match_rate}% ({matches}/{len(result_df)} matches)
πŸ“ Your Model: yassine123Z/EmissionFactor-mapper2-v2
πŸ“ Compared with: {hf_model_url}"""

        return summary, result_df

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return f"❌ Error: {str(e)}\n\nDetails:\n{error_details}", None
376
 
377
  # ==================================================
378
+ # πŸ–₯️ Main Gradio UI
379
  # ==================================================
380
# Build the four-tab Gradio interface: single classification, batch
# review/correction, corrections history, and model comparison.
with gr.Blocks(title="Transaction Category Classifier", theme=gr.themes.Soft()) as main_ui:
    gr.Markdown("# 🧾 Transaction Category Classifier")
    gr.Markdown("Upload CSV β†’ Review predictions in editable table β†’ Make corrections β†’ Save for training")

    # --- Tab 1: classify one transaction interactively -------------------
    with gr.Tab("πŸ”Ή Single Transaction"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Transaction Description",
                    placeholder="e.g., Plane ticket to Barcelona",
                    lines=2,
                )
                btn_submit = gr.Button("πŸ” Classify", variant="primary")

            with gr.Column():
                cat1_out = gr.Textbox(label="Predicted Category 1", interactive=False)
                cat2_out = gr.Textbox(label="Predicted Category 2", interactive=False)
                score_out = gr.Number(label="Confidence Score")

        btn_submit.click(
            fn=classify_single,
            inputs=text_input,
            outputs=[cat1_out, cat2_out, score_out],
        )

    # --- Tab 2: batch processing with an editable review table ------------
    with gr.Tab("πŸ“‚ Batch Review & Correct"):
        gr.Markdown("""
## πŸ“‹ How to use:
1. **Upload CSV** with 'transaction' column
2. **Review & Edit** the table below - modify 'Cat1' and 'Cat2' columns directly
3. **Save** corrections for model training
""")

        csv_input = gr.File(label="πŸ“ Upload CSV file", file_types=[".csv"])
        btn_process = gr.Button("πŸš€ Process & Load for Review", variant="primary", size="lg")

        process_status = gr.Textbox(label="Status", lines=3, interactive=False)

        gr.Markdown("### ✏️ Editable Results - Modify categories directly in the table:")
        review_table = gr.DataFrame(
            label="Review & Edit Predictions (Edit 'Cat1' and 'Cat2' columns directly)",
            interactive=True,
            wrap=True,
        )

        gr.Markdown("**Instructions:** Click any cell in 'Cat1' or 'Cat2' to edit directly. Changes will be tracked as corrections.")

        with gr.Row():
            btn_save = gr.Button("πŸ’Ύ Save All Corrections", variant="primary", size="lg")
            btn_download_corrected = gr.File(label="πŸ“₯ Download Corrected Results")
            btn_download_training = gr.File(label="πŸ“₯ Download Training Data (Corrections Only)")

        save_status = gr.Textbox(label="Save Status", lines=4, interactive=False)

        btn_process.click(
            fn=map_csv_for_review,
            inputs=csv_input,
            outputs=[review_table, process_status],
        )

        btn_save.click(
            fn=save_batch_corrections,
            inputs=review_table,
            outputs=[save_status, btn_download_corrected, btn_download_training],
        )

    # --- Tab 3: accumulated corrections across all batches ----------------
    with gr.Tab("πŸ“Š All Corrections History"):
        gr.Markdown("### πŸ“œ View All Saved Corrections")
        gr.Markdown("All corrections from all batches are stored here for model retraining.")

        btn_refresh = gr.Button("πŸ”„ Refresh List")
        corrections_table = gr.DataFrame(label="All Corrections", interactive=False)

        with gr.Row():
            btn_export_all = gr.Button("πŸ“€ Export All Corrections", variant="primary")
            export_all_file = gr.File(label="πŸ“₯ Download All Corrections CSV")
            export_status = gr.Textbox(label="Export Status", interactive=False)

        btn_refresh.click(fn=show_corrections, outputs=corrections_table)
        btn_export_all.click(fn=export_all_corrections, outputs=[export_all_file, export_status])

    # --- Tab 4: compare against an arbitrary HuggingFace SetFit model -----
    with gr.Tab("πŸ”¬ Compare Models"):
        gr.Markdown("## πŸ” Model Comparison Tool")
        gr.Markdown("Compare your model with any HuggingFace SetFit model.")

        hf_model_url = gr.Textbox(
            label="HuggingFace Model ID",
            placeholder="e.g., sentence-transformers/all-MiniLM-L6-v2",
        )

        compare_file = gr.File(label="Upload Test CSV", file_types=[".csv"])
        compare_btn = gr.Button("πŸ”¬ Compare Models", variant="primary")

        compare_summary = gr.Textbox(label="Comparison Summary", lines=6, interactive=False)
        compare_results = gr.DataFrame(label="Detailed Results", interactive=False)

        compare_btn.click(
            fn=compare_models_fixed,
            inputs=[hf_model_url, compare_file],
            outputs=[compare_summary, compare_results],
        )
482
  # ==================================================
483
+ # 🌐 Mount Gradio App
484
  # ==================================================
485
# Serve the Gradio UI from the FastAPI app at the web root.
app = gr.mount_gradio_app(app, main_ui, path="/")


# ==================================================
# 🧾 REST API Endpoints
# ==================================================
class TransactionsRequest(BaseModel):
    # List of free-text transaction descriptions to classify.
    transactions: List[str]


@app.get("/health")
def health_check():
    """Liveness probe: reports model availability and correction count."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "corrections_count": len(corrections_data),
    }


@app.post("/map_categories")
def map_categories(request: TransactionsRequest):
    """Classify each transaction and return its best category pair."""
    matched = []
    for text in request.transactions:
        cat1, cat2, score = classify_single(text)
        matched.append({
            "input_text": text,
            "best_Cat1": cat1,
            "best_Cat2": cat2,
            "similarity": score,
        })
    return {"matches": matched}


@app.get("/corrections")
def get_corrections():
    """Expose the in-memory corrections list and its size."""
    return {
        "corrections": corrections_data,
        "count": len(corrections_data),
    }


print("βœ… App initialized successfully!")
print("πŸ“ All features in one interface: /")
print("πŸ₯ Health Check: /health")
print("πŸ”Œ API: /map_categories, /corrections")