yassine123Z commited on
Commit
1e16b3a
Β·
verified Β·
1 Parent(s): 9c5d1b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +314 -116
app.py CHANGED
@@ -9,6 +9,7 @@ import torch
9
  import gradio as gr
10
  import tempfile
11
  import os
 
12
 
13
  # ==================================================
14
  # πŸš€ Initialize FastAPI
@@ -23,132 +24,355 @@ model = SetFitModel.from_pretrained("yassine123Z/EmissionFactor-mapper2-v2")
23
  # ==================================================
24
  # πŸ“˜ Reference Categories
25
  # ==================================================
26
- ref_data = pd.DataFrame({ "Cat1EN": [ "Purchase of goods","Purchase of goods","Purchase of goods","Purchase of goods", "Purchase of goods","Purchase of goods","Purchase of goods","Purchase of goods", "Purchase of goods","Purchase of goods","Purchase of materials","Purchase of materials", "Purchase of materials","Purchase of materials","Purchase of materials","Purchase of materials", "Purchase of services","Purchase of services","Purchase of services","Purchase of services", "Purchase of services","Purchase of services","Purchase of services","Purchase of services", "Purchase of services","Purchase of services","Purchase of services","Purchase of services", "Purchase of services","Purchase of services","Food & beverages","Food & beverages", "Food & beverages","Food & beverages","Food & beverages","Food & beverages", "Food & beverages","Food & beverages","Food & beverages","Food & beverages", "Heating and air conditioning","Heating and air conditioning","Fuels","Fuels","Fuels","Fuels", "Fuels","Fuels", "Mobility (freight)","Mobility (freight)","Mobility (freight)","Mobility (freight)", "Mobility (freight)", "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)", "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)", "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)", "Process and fugitive emissions","Process and fugitive emissions", "Process and fugitive emissions", "Waste treatment","Waste treatment","Waste treatment", "Waste treatment","Waste treatment","Waste treatment","Waste treatment","Waste treatment", "Waste treatment","Waste treatment","Waste treatment","Waste treatment", "Use of electricity","Use of electricity","Use of electricity" ],
27
-
28
- "Cat2EN": [ "Sporting goods","Buildings","Office supplies","Water consumption", "Household appliances","Electrical equipment","Machinery and equipment","Furniture", "Textiles and clothing","Vehicles","Construction materials","Organic materials", "Paper and cardboard","Plastics and rubber","Chemicals","Refrigerants and others", "Equipment rental","Building rental","Furniture rental","Vehicle rental and maintenance", "Information and cultural services","Catering services","Health services","Specialized craft services", "Administrative / consulting services","Cleaning services","IT services","Logistics services", "Marketing / advertising services","Technical services","Alcoholic beverages","Non-alcoholic beverages", "Condiments","Desserts","Fruits and vegetables","Fats and oils","Prepared / cooked meals", "Animal products","Cereal products","Dairy products","Heat and steam","Air conditioning and refrigeration", "Fossil fuels","Mobile fossil fuels","Organic fuels","Gaseous fossil fuels","Liquid fossil fuels", "Solid fossil fuels", "Air transport","Ship transport","Truck transport","Combined transport", "Train transport", "Air transport","Coach / Urban bus","Ship transport","Combined transport", "E-Bike","Accommodation / Events","Soft mobility","Motorcycle / Scooter","Train transport", "Public transport","Car", "Agriculture","Global warming potential","Industrial processes", "Commercial and industrial","Wastewater","Electrical equipment","Households and similar", "Metal","Organic materials","Paper and cardboard","Batteries and accumulators","Plastics", "Fugitive process emissions","Textiles","Glass", "Electricity for electric vehicles","Renewables","Standard" ],
29
-
30
- "DescriptionCat2EN": [ "Goods purchase - sports","Goods purchase - buildings","Goods purchase - office items","Goods purchase - water", "Goods purchase - appliances","Goods purchase - electricals","Goods purchase - machinery","Goods purchase - furniture", "Goods purchase - textiles","Goods purchase - vehicles","Material purchase - construction","Material purchase - organic", "Material purchase - paper","Material purchase - plastics","Material purchase - chemicals","Material purchase - refrigerants", "Service - equipment rental","Service - building rental","Service - furniture rental","Service - vehicles", "Service - info/culture","Service - catering","Service - healthcare","Service - crafts", "Service - admin/consulting","Service - cleaning","Service - IT","Service - logistics", "Service - marketing","Service - technical","Beverages - alcoholic","Beverages - non-alcoholic", "Food condiments","Food desserts","Food fruits & vegetables","Food fats & oils","Prepared meals", "Animal-based food","Cereal-based food","Dairy products","Heating - heat & steam","Heating - cooling/refrigeration", "Fuel - fossil","Fuel - mobile fossil","Fuel - organic","Fuel - gaseous","Fuel - liquid","Fuel - solid", "Freight transport - air","Freight transport - ship","Freight transport - truck","Freight transport - combined", "Freight transport - train", "Passenger transport - air","Passenger transport - bus","Passenger transport - ship", "Passenger transport - combined","Passenger transport - e-bike","Passenger transport - accommodation/events", "Passenger transport - soft mobility","Passenger transport - scooter/motorbike","Passenger transport - train", "Passenger transport - public","Passenger transport - car", "Emissions - agriculture","Emissions - warming potential", "Emissions - industry", "Waste - commercial/industrial","Waste - wastewater","Waste - electricals", "Waste - households","Waste - metals","Waste - organics","Waste - paper","Waste - batteries", "Waste - plastics","Waste - fugitive","Waste - textiles","Waste - glass", "Electricity - EVs","Electricity - renewables","Electricity - standard" ]
31
-
32
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  ref_data["combined"] = ref_data[["Cat1EN", "Cat2EN", "DescriptionCat2EN"]].agg(" ".join, axis=1)
34
  ref_embeddings = model.encode(ref_data["combined"].tolist())
35
 
 
 
 
 
36
  # ==================================================
37
- # πŸ” Core Classification Logic
38
  # ==================================================
39
- def classify_transaction(text: str):
40
- trans_emb = model.encode([text])[0]
41
- scores = util.pytorch_cos_sim(torch.tensor(trans_emb), torch.tensor(ref_embeddings)).flatten()
42
- best_idx = scores.argmax().item()
43
- cat1 = ref_data.iloc[best_idx]["Cat1EN"]
44
- cat2 = ref_data.iloc[best_idx]["Cat2EN"]
45
- score = float(scores[best_idx])
46
- return cat1, cat2, score
 
 
 
 
 
47
 
48
  # ==================================================
49
- # πŸ“‚ Batch Mapping
50
  # ==================================================
51
- def map_csv(file):
52
- df = pd.read_csv(file.name)
53
- if "transaction" not in df.columns:
54
- return "Error: Missing column 'transaction'.", None
 
 
 
 
55
 
56
  results = []
57
- for text in df["transaction"]:
58
- cat1, cat2, score = classify_transaction(text)
59
  results.append({
60
- "transaction": text,
61
- "Predicted Category 1": cat1,
62
- "Predicted Category 2": cat2,
63
- "Similarity Score": score
64
  })
 
 
65
 
66
- result_df = pd.DataFrame(results)
67
- tmp_dir = tempfile.mkdtemp()
68
- output_path = os.path.join(tmp_dir, "matched_results.csv")
69
- result_df.to_csv(output_path, index=False)
70
- return result_df, output_path
71
 
72
  # ==================================================
73
- # 🧠 Model Comparison Logic (no external URL call)
74
  # ==================================================
75
- def compare_models_ui(hf_model_url, file):
76
- import pandas as pd
77
- from setfit import SetFitModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
 
 
 
 
 
79
  try:
 
 
 
80
  df = pd.read_csv(file.name)
81
  if "transaction" not in df.columns:
82
  return "❌ Missing column 'transaction' in CSV.", None
83
 
84
- # Load the models
85
- local_model = SetFitModel.from_pretrained("yassine123Z/EmissionFactor-mapper2-v2")
86
- hf_model = SetFitModel.from_pretrained(hf_model_url)
 
 
 
 
87
 
88
  # Compare predictions
89
- local_preds, hf_preds, matches = [], [], []
90
- for text in df["transaction"]:
91
- local_pred = local_model.predict([text])[0]
92
- hf_pred = hf_model.predict([text])[0]
93
- local_preds.append(local_pred)
94
- hf_preds.append(hf_pred)
95
- matches.append(1.0 if local_pred == hf_pred else 0.0)
96
-
97
- df["local_pred"] = local_preds
98
- df["hf_pred"] = hf_preds
99
- df["match"] = matches
100
- match_rate = round(df["match"].mean() * 100, 2)
101
-
102
- summary = f"βœ… Match rate: {match_rate}% ({len(df)} rows)"
103
- return summary, df.head(10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  except Exception as e:
105
- return f"❌ Exception: {str(e)}", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  # ==================================================
108
  # πŸ–₯️ Gradio UI: Main App
109
  # ==================================================
110
- with gr.Blocks(title="Transaction Category Classifier") as gradio_ui:
111
- gr.Markdown("## 🧾 Transaction Category Classifier")
112
- gr.Markdown("Enter a transaction manually or upload a CSV file to classify multiple transactions.")
113
 
114
  with gr.Tab("πŸ”Ή Single Transaction"):
115
- text_input = gr.Textbox(label="Transaction Description", placeholder="e.g., Plane ticket to Barcelona")
116
- btn_submit = gr.Button("Submit")
117
- cat1_out = gr.Label(label="Predicted Category 1")
118
- cat2_out = gr.Label(label="Predicted Category 2")
119
- score_out = gr.Number(label="Similarity Score")
120
- btn_submit.click(fn=classify_transaction, inputs=text_input, outputs=[cat1_out, cat2_out, score_out])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  with gr.Tab("πŸ“‚ Batch CSV Upload"):
123
- csv_input = gr.File(label="Upload CSV file with 'transaction' column", file_types=[".csv"])
124
- btn_process = gr.Button("Process CSV")
125
- csv_output = gr.DataFrame(label="Matched Results")
126
- download_file = gr.File(label="Download Results CSV")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- def process_and_return(file):
129
- df, output_path = map_csv(file)
130
- if isinstance(df, str):
131
- return None, None
132
- return df, output_path
133
-
134
- btn_process.click(fn=process_and_return, inputs=csv_input, outputs=[csv_output, download_file])
 
 
 
 
 
 
 
 
 
135
 
136
  # ==================================================
137
- # 🧠 Gradio UI: Model Comparison Page
138
  # ==================================================
139
- with gr.Blocks(title="Model Comparison Tool") as compare_ui:
140
- gr.Markdown("### πŸ” Model Comparison Tool")
141
- hf_model_url = gr.Textbox(label="Hugging Face model URL")
142
- file = gr.File(label="Upload test dataset (CSV)")
143
- compare_btn = gr.Button("Compare Models")
144
- output_text = gr.Textbox(label="Summary")
145
- output_table = gr.DataFrame(label="Sample results")
146
- compare_btn.click(fn=compare_models_ui, inputs=[hf_model_url, file], outputs=[output_text, output_table])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  # ==================================================
149
  # 🌐 Mount Gradio inside FastAPI
150
  # ==================================================
151
- app = gr.mount_gradio_app(app, gradio_ui, path="/ui")
152
  app = gr.mount_gradio_app(app, compare_ui, path="/compare")
153
 
154
  # ==================================================
@@ -157,15 +381,11 @@ app = gr.mount_gradio_app(app, compare_ui, path="/compare")
157
  class TransactionsRequest(BaseModel):
158
  transactions: List[str]
159
 
160
- @app.get("/")
161
- def read_root():
162
- return {"status": "ok", "message": "Use /ui or /compare for Gradio, or /map_categories for API."}
163
-
164
  @app.post("/map_categories")
165
  def map_categories(request: TransactionsRequest):
166
  results = []
167
  for text in request.transactions:
168
- cat1, cat2, score = classify_transaction(text)
169
  results.append({
170
  "input_text": text,
171
  "best_Cat1": cat1,
@@ -174,29 +394,7 @@ def map_categories(request: TransactionsRequest):
174
  })
175
  return {"matches": results}
176
 
177
- @app.post("/compare_models")
178
- def compare_models(request: dict):
179
- local_model = SetFitModel.from_pretrained("yassine123Z/EmissionFactor-mapper2-v2")
180
- hf_model = SetFitModel.from_pretrained(request["hf_model"])
181
- df = pd.DataFrame(request["data"])
182
- if "transaction" not in df.columns:
183
- return {"error": "Missing column 'transaction'."}
184
-
185
- local_preds, hf_preds, matches = [], [], []
186
- for text in df["transaction"]:
187
- local_pred = local_model.predict([text])[0]
188
- hf_pred = hf_model.predict([text])[0]
189
- local_preds.append(local_pred)
190
- hf_preds.append(hf_pred)
191
- matches.append(1.0 if local_pred == hf_pred else 0.0)
192
-
193
- df["local_pred"] = local_preds
194
- df["hf_pred"] = hf_preds
195
- df["match"] = matches
196
- match_rate = round(df["match"].mean() * 100, 2)
197
-
198
- return {
199
- "match_rate": match_rate,
200
- "total_records": len(df),
201
- "sample_results": df.head(10).to_dict(orient="records")
202
- }
 
9
  import gradio as gr
10
  import tempfile
11
  import os
12
+ from datetime import datetime
13
 
14
  # ==================================================
15
  # πŸš€ Initialize FastAPI
 
24
  # ==================================================
25
  # πŸ“˜ Reference Categories
26
  # ==================================================
27
+ ref_data = pd.DataFrame({
28
+ "Cat1EN": [
29
+ "Purchase of goods","Purchase of goods","Purchase of goods","Purchase of goods",
30
+ "Purchase of goods","Purchase of goods","Purchase of goods","Purchase of goods",
31
+ "Purchase of goods","Purchase of goods","Purchase of materials","Purchase of materials",
32
+ "Purchase of materials","Purchase of materials","Purchase of materials","Purchase of materials",
33
+ "Purchase of services","Purchase of services","Purchase of services","Purchase of services",
34
+ "Purchase of services","Purchase of services","Purchase of services","Purchase of services",
35
+ "Purchase of services","Purchase of services","Purchase of services","Purchase of services",
36
+ "Purchase of services","Purchase of services","Food & beverages","Food & beverages",
37
+ "Food & beverages","Food & beverages","Food & beverages","Food & beverages",
38
+ "Food & beverages","Food & beverages","Food & beverages","Food & beverages",
39
+ "Heating and air conditioning","Heating and air conditioning","Fuels","Fuels","Fuels","Fuels",
40
+ "Fuels","Fuels",
41
+ "Mobility (freight)","Mobility (freight)","Mobility (freight)","Mobility (freight)",
42
+ "Mobility (freight)",
43
+ "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
44
+ "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
45
+ "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
46
+ "Process and fugitive emissions","Process and fugitive emissions",
47
+ "Process and fugitive emissions",
48
+ "Waste treatment","Waste treatment","Waste treatment",
49
+ "Waste treatment","Waste treatment","Waste treatment","Waste treatment","Waste treatment",
50
+ "Waste treatment","Waste treatment","Waste treatment","Waste treatment",
51
+ "Use of electricity","Use of electricity","Use of electricity"
52
+ ],
53
+ "Cat2EN": [
54
+ "Sporting goods","Buildings","Office supplies","Water consumption",
55
+ "Household appliances","Electrical equipment","Machinery and equipment","Furniture",
56
+ "Textiles and clothing","Vehicles","Construction materials","Organic materials",
57
+ "Paper and cardboard","Plastics and rubber","Chemicals","Refrigerants and others",
58
+ "Equipment rental","Building rental","Furniture rental","Vehicle rental and maintenance",
59
+ "Information and cultural services","Catering services","Health services","Specialized craft services",
60
+ "Administrative / consulting services","Cleaning services","IT services","Logistics services",
61
+ "Marketing / advertising services","Technical services","Alcoholic beverages","Non-alcoholic beverages",
62
+ "Condiments","Desserts","Fruits and vegetables","Fats and oils","Prepared / cooked meals",
63
+ "Animal products","Cereal products","Dairy products","Heat and steam","Air conditioning and refrigeration",
64
+ "Fossil fuels","Mobile fossil fuels","Organic fuels","Gaseous fossil fuels","Liquid fossil fuels",
65
+ "Solid fossil fuels",
66
+ "Air transport","Ship transport","Truck transport","Combined transport",
67
+ "Train transport",
68
+ "Air transport","Coach / Urban bus","Ship transport","Combined transport",
69
+ "E-Bike","Accommodation / Events","Soft mobility","Motorcycle / Scooter","Train transport",
70
+ "Public transport","Car",
71
+ "Agriculture","Global warming potential","Industrial processes",
72
+ "Commercial and industrial","Wastewater","Electrical equipment","Households and similar",
73
+ "Metal","Organic materials","Paper and cardboard","Batteries and accumulators","Plastics",
74
+ "Fugitive process emissions","Textiles","Glass",
75
+ "Electricity for electric vehicles","Renewables","Standard"
76
+ ],
77
+ "DescriptionCat2EN": [
78
+ "Goods purchase - sports","Goods purchase - buildings","Goods purchase - office items","Goods purchase - water",
79
+ "Goods purchase - appliances","Goods purchase - electricals","Goods purchase - machinery","Goods purchase - furniture",
80
+ "Goods purchase - textiles","Goods purchase - vehicles","Material purchase - construction","Material purchase - organic",
81
+ "Material purchase - paper","Material purchase - plastics","Material purchase - chemicals","Material purchase - refrigerants",
82
+ "Service - equipment rental","Service - building rental","Service - furniture rental","Service - vehicles",
83
+ "Service - info/culture","Service - catering","Service - healthcare","Service - crafts",
84
+ "Service - admin/consulting","Service - cleaning","Service - IT","Service - logistics",
85
+ "Service - marketing","Service - technical","Beverages - alcoholic","Beverages - non-alcoholic",
86
+ "Food condiments","Food desserts","Food fruits & vegetables","Food fats & oils","Prepared meals",
87
+ "Animal-based food","Cereal-based food","Dairy products","Heating - heat & steam","Heating - cooling/refrigeration",
88
+ "Fuel - fossil","Fuel - mobile fossil","Fuel - organic","Fuel - gaseous","Fuel - liquid","Fuel - solid",
89
+ "Freight transport - air","Freight transport - ship","Freight transport - truck","Freight transport - combined",
90
+ "Freight transport - train",
91
+ "Passenger transport - air","Passenger transport - bus","Passenger transport - ship",
92
+ "Passenger transport - combined","Passenger transport - e-bike","Passenger transport - accommodation/events",
93
+ "Passenger transport - soft mobility","Passenger transport - scooter/motorbike","Passenger transport - train",
94
+ "Passenger transport - public","Passenger transport - car",
95
+ "Emissions - agriculture","Emissions - warming potential",
96
+ "Emissions - industry",
97
+ "Waste - commercial/industrial","Waste - wastewater","Waste - electricals",
98
+ "Waste - households","Waste - metals","Waste - organics","Waste - paper","Waste - batteries",
99
+ "Waste - plastics","Waste - fugitive","Waste - textiles","Waste - glass",
100
+ "Electricity - EVs","Electricity - renewables","Electricity - standard"
101
+ ]
102
+ })
103
+
104
  ref_data["combined"] = ref_data[["Cat1EN", "Cat2EN", "DescriptionCat2EN"]].agg(" ".join, axis=1)
105
  ref_embeddings = model.encode(ref_data["combined"].tolist())
106
 
107
+ # Get unique categories for dropdowns
108
+ unique_cat1 = sorted(ref_data["Cat1EN"].unique().tolist())
109
+ unique_cat2 = sorted(ref_data["Cat2EN"].unique().tolist())
110
+
111
  # ==================================================
112
+ # πŸ’Ύ Corrections Storage (in-memory, use DB in production)
113
  # ==================================================
114
+ corrections_data = []
115
+
116
+ def save_correction(transaction, predicted_cat1, predicted_cat2, correct_cat1, correct_cat2):
117
+ """Save user correction for future model improvement"""
118
+ corrections_data.append({
119
+ "timestamp": datetime.now().isoformat(),
120
+ "transaction": transaction,
121
+ "predicted_cat1": predicted_cat1,
122
+ "predicted_cat2": predicted_cat2,
123
+ "correct_cat1": correct_cat1,
124
+ "correct_cat2": correct_cat2
125
+ })
126
+ return f"βœ… Correction saved! Total corrections: {len(corrections_data)}"
127
 
128
  # ==================================================
129
+ # πŸ” Core Classification Logic with Top-K
130
  # ==================================================
131
+ def classify_transaction(text: str, top_k=3):
132
+ """Classify with top-K results for review"""
133
+ trans_emb = model.encode([text])[0]
134
+ scores = util.pytorch_cos_sim(torch.tensor(trans_emb), torch.tensor(ref_embeddings)).flatten()
135
+
136
+ # Get top-k matches
137
+ top_k_indices = scores.topk(top_k).indices.tolist()
138
+ top_k_scores = scores.topk(top_k).values.tolist()
139
 
140
  results = []
141
+ for idx, score in zip(top_k_indices, top_k_scores):
 
142
  results.append({
143
+ "cat1": ref_data.iloc[idx]["Cat1EN"],
144
+ "cat2": ref_data.iloc[idx]["Cat2EN"],
145
+ "score": float(score)
 
146
  })
147
+
148
+ return results
149
 
150
+ def classify_single(text: str):
151
+ """For simple single classification"""
152
+ results = classify_transaction(text, top_k=1)
153
+ return results[0]["cat1"], results[0]["cat2"], results[0]["score"]
 
154
 
155
  # ==================================================
156
+ # πŸ“‚ Batch Mapping with Review
157
  # ==================================================
158
+ def map_csv_with_review(file):
159
+ """Process CSV and return results for review"""
160
+ try:
161
+ df = pd.read_csv(file.name)
162
+ if "transaction" not in df.columns:
163
+ return "❌ Error: Missing column 'transaction'.", None, None
164
+
165
+ results = []
166
+ for idx, text in enumerate(df["transaction"]):
167
+ top_matches = classify_transaction(text, top_k=3)
168
+ results.append({
169
+ "row_id": idx,
170
+ "transaction": text,
171
+ "cat1_pred": top_matches[0]["cat1"],
172
+ "cat2_pred": top_matches[0]["cat2"],
173
+ "confidence": round(top_matches[0]["score"], 3),
174
+ "cat1_alt1": top_matches[1]["cat1"] if len(top_matches) > 1 else "",
175
+ "cat2_alt1": top_matches[1]["cat2"] if len(top_matches) > 1 else "",
176
+ "confidence_alt1": round(top_matches[1]["score"], 3) if len(top_matches) > 1 else 0,
177
+ "status": "βœ… High" if top_matches[0]["score"] > 0.8 else "⚠️ Review"
178
+ })
179
+
180
+ result_df = pd.DataFrame(results)
181
+
182
+ # Save temporary file
183
+ tmp_dir = tempfile.mkdtemp()
184
+ output_path = os.path.join(tmp_dir, "mapped_results.csv")
185
+ result_df.to_csv(output_path, index=False)
186
+
187
+ return result_df, output_path, f"βœ… Processed {len(result_df)} transactions"
188
+
189
+ except Exception as e:
190
+ return f"❌ Error: {str(e)}", None, None
191
 
192
+ # ==================================================
193
+ # πŸ”§ Model Comparison (Fixed)
194
+ # ==================================================
195
+ def compare_models_fixed(hf_model_url, file):
196
+ """Fixed comparison function"""
197
  try:
198
+ if not hf_model_url or not file:
199
+ return "❌ Please provide both model URL and CSV file", None
200
+
201
  df = pd.read_csv(file.name)
202
  if "transaction" not in df.columns:
203
  return "❌ Missing column 'transaction' in CSV.", None
204
 
205
+ # Load models
206
+ local_model = model # Use already loaded model
207
+ hf_model = SetFitModel.from_pretrained(hf_model_url.strip())
208
+
209
+ # Get embeddings
210
+ local_embs = local_model.encode(ref_data["combined"].tolist())
211
+ hf_embs = hf_model.encode(ref_data["combined"].tolist())
212
 
213
  # Compare predictions
214
+ results = []
215
+ for text in df["transaction"][:50]: # Limit to 50 for speed
216
+ # Local prediction
217
+ trans_emb_local = local_model.encode([text])[0]
218
+ scores_local = util.pytorch_cos_sim(torch.tensor(trans_emb_local), torch.tensor(local_embs)).flatten()
219
+ best_idx_local = scores_local.argmax().item()
220
+
221
+ # HF prediction
222
+ trans_emb_hf = hf_model.encode([text])[0]
223
+ scores_hf = util.pytorch_cos_sim(torch.tensor(trans_emb_hf), torch.tensor(hf_embs)).flatten()
224
+ best_idx_hf = scores_hf.argmax().item()
225
+
226
+ match = "βœ…" if best_idx_local == best_idx_hf else "❌"
227
+
228
+ results.append({
229
+ "transaction": text,
230
+ "local_cat1": ref_data.iloc[best_idx_local]["Cat1EN"],
231
+ "local_cat2": ref_data.iloc[best_idx_local]["Cat2EN"],
232
+ "local_score": round(float(scores_local[best_idx_local]), 3),
233
+ "hf_cat1": ref_data.iloc[best_idx_hf]["Cat1EN"],
234
+ "hf_cat2": ref_data.iloc[best_idx_hf]["Cat2EN"],
235
+ "hf_score": round(float(scores_hf[best_idx_hf]), 3),
236
+ "match": match
237
+ })
238
+
239
+ result_df = pd.DataFrame(results)
240
+ matches = (result_df["match"] == "βœ…").sum()
241
+ match_rate = round(matches / len(result_df) * 100, 2)
242
+
243
+ summary = f"βœ… Compared {len(result_df)} transactions\nπŸ“Š Match rate: {match_rate}% ({matches}/{len(result_df)})"
244
+ return summary, result_df
245
+
246
  except Exception as e:
247
+ return f"❌ Error: {str(e)}", None
248
+
249
+ # ==================================================
250
+ # πŸ“₯ Export Corrections
251
+ # ==================================================
252
+ def export_corrections():
253
+ """Export corrections to CSV"""
254
+ if not corrections_data:
255
+ return None, "⚠️ No corrections to export"
256
+
257
+ df = pd.DataFrame(corrections_data)
258
+ tmp_dir = tempfile.mkdtemp()
259
+ output_path = os.path.join(tmp_dir, f"corrections_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv")
260
+ df.to_csv(output_path, index=False)
261
+ return output_path, f"βœ… Exported {len(corrections_data)} corrections"
262
 
263
  # ==================================================
264
  # πŸ–₯️ Gradio UI: Main App
265
  # ==================================================
266
+ with gr.Blocks(title="Transaction Category Classifier", theme=gr.themes.Soft()) as gradio_ui:
267
+ gr.Markdown("# 🧾 Transaction Category Classifier")
268
+ gr.Markdown("Classify transactions and review/correct predictions to improve the model.")
269
 
270
  with gr.Tab("πŸ”Ή Single Transaction"):
271
+ with gr.Row():
272
+ with gr.Column():
273
+ text_input = gr.Textbox(
274
+ label="Transaction Description",
275
+ placeholder="e.g., Plane ticket to Barcelona",
276
+ lines=2
277
+ )
278
+ btn_submit = gr.Button("πŸ” Classify", variant="primary")
279
+
280
+ with gr.Column():
281
+ cat1_out = gr.Textbox(label="Predicted Category 1")
282
+ cat2_out = gr.Textbox(label="Predicted Category 2")
283
+ score_out = gr.Number(label="Confidence Score")
284
+
285
+ gr.Markdown("### ✏️ Review & Correct")
286
+ with gr.Row():
287
+ correct_cat1 = gr.Dropdown(choices=unique_cat1, label="Correct Category 1")
288
+ correct_cat2 = gr.Dropdown(choices=unique_cat2, label="Correct Category 2")
289
+
290
+ btn_save_correction = gr.Button("πŸ’Ύ Save Correction")
291
+ correction_status = gr.Textbox(label="Status")
292
+
293
+ # Event handlers
294
+ btn_submit.click(
295
+ fn=classify_single,
296
+ inputs=text_input,
297
+ outputs=[cat1_out, cat2_out, score_out]
298
+ )
299
+
300
+ btn_save_correction.click(
301
+ fn=save_correction,
302
+ inputs=[text_input, cat1_out, cat2_out, correct_cat1, correct_cat2],
303
+ outputs=correction_status
304
+ )
305
 
306
  with gr.Tab("πŸ“‚ Batch CSV Upload"):
307
+ gr.Markdown("Upload a CSV file with a 'transaction' column to classify multiple transactions.")
308
+
309
+ csv_input = gr.File(label="Upload CSV file", file_types=[".csv"])
310
+ btn_process = gr.Button("πŸš€ Process CSV", variant="primary")
311
+
312
+ process_status = gr.Textbox(label="Status")
313
+ csv_output = gr.DataFrame(label="Classification Results (scroll right for alternatives)")
314
+ download_file = gr.File(label="πŸ“₯ Download Results CSV")
315
+
316
+ btn_process.click(
317
+ fn=map_csv_with_review,
318
+ inputs=csv_input,
319
+ outputs=[csv_output, download_file, process_status]
320
+ )
321
+
322
+ gr.Markdown("""
323
+ **Legend:**
324
+ - βœ… High: Confidence > 80%
325
+ - ⚠️ Review: Confidence < 80% - please review
326
+ - Alternative predictions provided for low-confidence matches
327
+ """)
328
 
329
+ with gr.Tab("πŸ“Š View Corrections"):
330
+ gr.Markdown("### Review and export saved corrections")
331
+
332
+ btn_refresh = gr.Button("πŸ”„ Refresh Corrections")
333
+ corrections_df = gr.DataFrame(label="Saved Corrections")
334
+ export_status = gr.Textbox(label="Export Status")
335
+ export_file = gr.File(label="πŸ“₯ Download Corrections CSV")
336
+ btn_export = gr.Button("πŸ“€ Export All Corrections")
337
+
338
+ def show_corrections():
339
+ if not corrections_data:
340
+ return pd.DataFrame({"message": ["No corrections yet"]})
341
+ return pd.DataFrame(corrections_data)
342
+
343
+ btn_refresh.click(fn=show_corrections, outputs=corrections_df)
344
+ btn_export.click(fn=export_corrections, outputs=[export_file, export_status])
345
 
346
  # ==================================================
347
+ # πŸ” Gradio UI: Model Comparison Page
348
  # ==================================================
349
+ with gr.Blocks(title="Model Comparison Tool", theme=gr.themes.Soft()) as compare_ui:
350
+ gr.Markdown("# πŸ” Model Comparison Tool")
351
+ gr.Markdown("Compare predictions between your local model and any HuggingFace model.")
352
+
353
+ with gr.Row():
354
+ hf_model_url = gr.Textbox(
355
+ label="HuggingFace Model ID",
356
+ placeholder="e.g., sentence-transformers/all-MiniLM-L6-v2",
357
+ info="Enter the model ID from HuggingFace"
358
+ )
359
+
360
+ file = gr.File(label="Upload test dataset (CSV with 'transaction' column)", file_types=[".csv"])
361
+ compare_btn = gr.Button("πŸ”¬ Compare Models", variant="primary")
362
+
363
+ output_text = gr.Textbox(label="Comparison Summary", lines=3)
364
+ output_table = gr.DataFrame(label="Detailed Comparison Results")
365
+
366
+ compare_btn.click(
367
+ fn=compare_models_fixed,
368
+ inputs=[hf_model_url, file],
369
+ outputs=[output_text, output_table]
370
+ )
371
 
372
  # ==================================================
373
  # 🌐 Mount Gradio inside FastAPI
374
  # ==================================================
375
+ app = gr.mount_gradio_app(app, gradio_ui, path="/")
376
  app = gr.mount_gradio_app(app, compare_ui, path="/compare")
377
 
378
  # ==================================================
 
381
  class TransactionsRequest(BaseModel):
382
  transactions: List[str]
383
 
 
 
 
 
384
  @app.post("/map_categories")
385
  def map_categories(request: TransactionsRequest):
386
  results = []
387
  for text in request.transactions:
388
+ cat1, cat2, score = classify_single(text)
389
  results.append({
390
  "input_text": text,
391
  "best_Cat1": cat1,
 
394
  })
395
  return {"matches": results}
396
 
397
+ @app.get("/corrections")
398
+ def get_corrections():
399
+ """API endpoint to retrieve all corrections"""
400
+ return {"corrections": corrections_data, "count": len(corrections_data)}