yassine123Z committed on
Commit
6331c93
·
verified ·
1 Parent(s): f9eaa32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -58
app.py CHANGED
@@ -7,6 +7,8 @@ from setfit import SetFitModel
7
  from sentence_transformers import util
8
  import torch
9
  import gradio as gr
 
 
10
 
11
  # ==================================================
12
  # 🚀 Initialize FastAPI
@@ -23,15 +25,18 @@ model = SetFitModel.from_pretrained(
23
  # ==================================================
24
  # 📘 Load Reference Categories
25
  # ==================================================
26
- ref_data = pd.DataFrame({ "Cat1EN": [ "Purchase of goods","Purchase of goods","Purchase of goods","Purchase of goods", "Purchase of goods","Purchase of goods","Purchase of goods","Purchase of goods", "Purchase of goods","Purchase of goods","Purchase of materials","Purchase of materials", "Purchase of materials","Purchase of materials","Purchase of materials","Purchase of materials", "Purchase of services","Purchase of services","Purchase of services","Purchase of services", "Purchase of services","Purchase of services","Purchase of services","Purchase of services", "Purchase of services","Purchase of services","Purchase of services","Purchase of services", "Purchase of services","Purchase of services","Food & beverages","Food & beverages", "Food & beverages","Food & beverages","Food & beverages","Food & beverages", "Food & beverages","Food & beverages","Food & beverages","Food & beverages", "Heating and air conditioning","Heating and air conditioning","Fuels","Fuels","Fuels","Fuels", "Fuels","Fuels", "Mobility (freight)","Mobility (freight)","Mobility (freight)","Mobility (freight)", "Mobility (freight)", "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)", "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)", "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)", "Process and fugitive emissions","Process and fugitive emissions", "Process and fugitive emissions", "Waste treatment","Waste treatment","Waste treatment", "Waste treatment","Waste treatment","Waste treatment","Waste treatment","Waste treatment", "Waste treatment","Waste treatment","Waste treatment","Waste treatment", "Use of electricity","Use of electricity","Use of electricity" ],
27
-
28
- "Cat2EN": [ "Sporting goods","Buildings","Office supplies","Water consumption", "Household appliances","Electrical equipment","Machinery and equipment","Furniture", "Textiles and clothing","Vehicles","Construction materials","Organic materials", "Paper and cardboard","Plastics and rubber","Chemicals","Refrigerants and others", "Equipment rental","Building rental","Furniture rental","Vehicle rental and maintenance", "Information and cultural services","Catering services","Health services","Specialized craft services", "Administrative / consulting services","Cleaning services","IT services","Logistics services", "Marketing / advertising services","Technical services","Alcoholic beverages","Non-alcoholic beverages", "Condiments","Desserts","Fruits and vegetables","Fats and oils","Prepared / cooked meals", "Animal products","Cereal products","Dairy products","Heat and steam","Air conditioning and refrigeration", "Fossil fuels","Mobile fossil fuels","Organic fuels","Gaseous fossil fuels","Liquid fossil fuels", "Solid fossil fuels", "Air transport","Ship transport","Truck transport","Combined transport", "Train transport", "Air transport","Coach / Urban bus","Ship transport","Combined transport", "E-Bike","Accommodation / Events","Soft mobility","Motorcycle / Scooter","Train transport", "Public transport","Car", "Agriculture","Global warming potential","Industrial processes", "Commercial and industrial","Wastewater","Electrical equipment","Households and similar", "Metal","Organic materials","Paper and cardboard","Batteries and accumulators","Plastics", "Fugitive process emissions","Textiles","Glass", "Electricity for electric vehicles","Renewables","Standard" ],
29
-
30
- "DescriptionCat2EN": [ "Goods purchase - sports","Goods purchase - buildings","Goods purchase - office items","Goods purchase - water", "Goods purchase - appliances","Goods purchase - electricals","Goods purchase - machinery","Goods purchase - furniture", "Goods purchase - textiles","Goods purchase - vehicles","Material purchase - construction","Material purchase - organic", "Material purchase - paper","Material purchase - plastics","Material purchase - chemicals","Material purchase - refrigerants", "Service - equipment rental","Service - building rental","Service - furniture rental","Service - vehicles", "Service - info/culture","Service - catering","Service - healthcare","Service - crafts", "Service - admin/consulting","Service - cleaning","Service - IT","Service - logistics", "Service - marketing","Service - technical","Beverages - alcoholic","Beverages - non-alcoholic", "Food condiments","Food desserts","Food fruits & vegetables","Food fats & oils","Prepared meals", "Animal-based food","Cereal-based food","Dairy products","Heating - heat & steam","Heating - cooling/refrigeration", "Fuel - fossil","Fuel - mobile fossil","Fuel - organic","Fuel - gaseous","Fuel - liquid","Fuel - solid", "Freight transport - air","Freight transport - ship","Freight transport - truck","Freight transport - combined", "Freight transport - train", "Passenger transport - air","Passenger transport - bus","Passenger transport - ship", "Passenger transport - combined","Passenger transport - e-bike","Passenger transport - accommodation/events", "Passenger transport - soft mobility","Passenger transport - scooter/motorbike","Passenger transport - train", "Passenger transport - public","Passenger transport - car", "Emissions - agriculture","Emissions - warming potential", "Emissions - industry", "Waste - commercial/industrial","Waste - wastewater","Waste - electricals", "Waste - households","Waste - metals","Waste - organics","Waste - paper","Waste - batteries", "Waste - plastics","Waste - 
fugitive","Waste - textiles","Waste - glass", "Electricity - EVs","Electricity - renewables","Electricity - standard" ]
31
-
32
- })
33
-
34
- # combine all category info into a single string for embeddings
 
 
 
35
  ref_data["combined"] = ref_data[["Cat1EN", "Cat2EN", "DescriptionCat2EN"]].agg(" ".join, axis=1)
36
  ref_embeddings = model.encode(ref_data["combined"].tolist())
37
 
@@ -48,25 +53,68 @@ def classify_transaction(text):
48
  return cat1, cat2, score
49
 
50
  # ==================================================
51
- # 🖥️ Gradio Interface
52
  # ==================================================
53
- gradio_ui = gr.Interface(
54
- fn=classify_transaction,
55
- inputs=gr.Textbox(lines=3, label="Transaction Description", placeholder="Enter a transaction text..."),
56
- outputs=[
57
- gr.Label(label="Predicted Category 1"),
58
- gr.Label(label="Predicted Category 2"),
59
- gr.Number(label="Similarity Score")
60
- ],
61
- title="Transaction Category Classifier",
62
- description="Enter a transaction description and get the best-matching category using SetFit embeddings.",
63
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  # Mount Gradio inside FastAPI at /ui
66
  app = gr.mount_gradio_app(app, gradio_ui, path="/ui")
67
 
68
  # ==================================================
69
- # 🧾 API Endpoints
70
  # ==================================================
71
  class TransactionsRequest(BaseModel):
72
  transactions: List[str]
@@ -79,43 +127,11 @@ def read_root():
79
  def map_categories(request: TransactionsRequest):
80
  results = []
81
  for text in request.transactions:
82
- trans_emb = model.encode([text])[0]
83
- scores = util.pytorch_cos_sim(torch.tensor(trans_emb), torch.tensor(ref_embeddings)).flatten()
84
- best_idx = scores.argmax().item()
85
  results.append({
86
  "input_text": text,
87
- "best_Cat1": ref_data.iloc[best_idx]["Cat1EN"],
88
- "best_Cat2": ref_data.iloc[best_idx]["Cat2EN"],
89
- "similarity": float(scores[best_idx])
90
  })
91
  return {"matches": results}
92
-
93
-
94
- feedback_data = "feedback.csv"
95
-
96
- @app.post("/feedback/")
97
- def submit_feedback(text: str, predicted_label: str, correct_label: str):
98
- df = pd.DataFrame([[text, predicted_label, correct_label]],
99
- columns=["text", "predicted_label", "correct_label"])
100
- df.to_csv(feedback_data, mode='a', header=False, index=False)
101
- return {"message": "Feedback saved successfully"}
102
-
103
-
104
- @app.post("/map_categories_csv/")
105
- async def map_categories_csv(file: UploadFile = File(...)):
106
- df = pd.read_csv(file.file)
107
- results = []
108
- for text in df['transaction']:
109
- trans_emb = model.encode([text])[0]
110
- scores = util.pytorch_cos_sim(torch.tensor(trans_emb), torch.tensor(ref_embeddings)).flatten()
111
- best_idx = scores.argmax().item()
112
- results.append({
113
- "input_text": text,
114
- "best_Cat1": ref_data.iloc[best_idx]["Cat1EN"],
115
- "best_Cat2": ref_data.iloc[best_idx]["Cat2EN"],
116
- "similarity": float(scores[best_idx])
117
- })
118
- result_df = pd.DataFrame(results)
119
- output_file = "results.csv"
120
- result_df.to_csv(output_file, index=False)
121
- return FileResponse(output_file, media_type='text/csv', filename="matched_results.csv")
 
7
  from sentence_transformers import util
8
  import torch
9
  import gradio as gr
10
+ import tempfile
11
+ import os
12
 
13
  # ==================================================
14
  # 🚀 Initialize FastAPI
 
25
  # ==================================================
26
  # 📘 Load Reference Categories
27
  # ==================================================
28
+ ref_data = pd.DataFrame({
29
+ "Cat1EN": ["Purchase of goods", "Mobility (passengers)", "Waste treatment", "Use of electricity"],
30
+ "Cat2EN": ["Office supplies", "Air transport", "Wastewater", "Renewables"],
31
+ "DescriptionCat2EN": [
32
+ "Goods purchase - office items",
33
+ "Passenger transport - air",
34
+ "Waste - wastewater",
35
+ "Electricity - renewables"
36
+ ]
37
+ })
38
+
39
+ # Combine all category info into a single string for embeddings
40
  ref_data["combined"] = ref_data[["Cat1EN", "Cat2EN", "DescriptionCat2EN"]].agg(" ".join, axis=1)
41
  ref_embeddings = model.encode(ref_data["combined"].tolist())
42
 
 
53
  return cat1, cat2, score
54
 
55
  # ==================================================
56
+ # 📂 CSV Mapping Function
57
  # ==================================================
58
def map_csv(file):
    """Classify every row of an uploaded CSV against the reference categories.

    Args:
        file: The uploaded CSV. Either a filesystem path (``str`` or
            ``os.PathLike``) or an object that exposes the path as ``.name``.
            ``None`` means nothing was uploaded.

    Returns:
        A 2-tuple. On success it is ``(result_df, output_path)``:
        ``result_df`` has the columns ``transaction``,
        ``Predicted Category 1``, ``Predicted Category 2`` and
        ``Similarity Score``; ``output_path`` is a CSV copy of it named
        ``matched_results.csv`` inside a fresh temporary directory.
        On invalid input it is ``(error_message, None)`` where
        ``error_message`` is a string.
    """
    if file is None:
        return "Error: No file uploaded. Please upload a CSV file.", None

    # A plain path is used as-is; anything else is expected to carry the
    # path in `.name`, as the original implementation assumed.
    csv_path = file if isinstance(file, (str, os.PathLike)) else file.name
    df = pd.read_csv(csv_path)
    if "transaction" not in df.columns:
        return "Error: Missing column 'transaction'. Please include it in your CSV.", None

    columns = [
        "transaction",
        "Predicted Category 1",
        "Predicted Category 2",
        "Similarity Score",
    ]
    transactions = df["transaction"].tolist()
    results = []

    if transactions:
        # Blank cells are read as NaN and numeric cells as numbers; the
        # encoder needs strings, so normalise before encoding.
        texts = ["" if pd.isna(value) else str(value) for value in transactions]

        # Encode all rows in one call and compare against the reference
        # embeddings once, instead of once per row.
        trans_embs = model.encode(texts)
        scores = util.pytorch_cos_sim(
            torch.tensor(trans_embs), torch.tensor(ref_embeddings)
        )
        best_scores, best_indices = scores.max(dim=1)

        cat1_values = ref_data["Cat1EN"].tolist()
        cat2_values = ref_data["Cat2EN"].tolist()
        for text, best_idx, best_score in zip(
            transactions, best_indices.tolist(), best_scores.tolist()
        ):
            results.append({
                "transaction": text,
                "Predicted Category 1": cat1_values[best_idx],
                "Predicted Category 2": cat2_values[best_idx],
                "Similarity Score": float(best_score),
            })

    # Passing `columns` keeps the header intact even when there are no rows.
    result_df = pd.DataFrame(results, columns=columns)

    # Save to a temporary file for download. The directory is intentionally
    # not removed here because the caller still has to serve the file.
    tmp_dir = tempfile.mkdtemp()
    output_path = os.path.join(tmp_dir, "matched_results.csv")
    result_df.to_csv(output_path, index=False)

    return result_df, output_path
83
+
84
+ # ==================================================
85
+ # 🖥️ Gradio Interface with Upload + Download
86
+ # ==================================================
87
# Two-tab Gradio UI: one tab classifies a single typed transaction, the other
# classifies every row of an uploaded CSV and offers the result for download.
with gr.Blocks(title="Transaction Category Classifier") as gradio_ui:
    gr.Markdown("## 🧾 Transaction Category Classifier")
    gr.Markdown("Enter a transaction manually or upload a CSV file to classify multiple transactions.")

    with gr.Tab("🔹 Single Transaction"):
        text_input = gr.Textbox(label="Transaction Description", placeholder="e.g., going to Barcelona using plane")
        btn_submit = gr.Button("Submit")
        cat1_out = gr.Label(label="Predicted Category 1")
        cat2_out = gr.Label(label="Predicted Category 2")
        score_out = gr.Number(label="Similarity Score")
        # classify_transaction returns (cat1, cat2, score), matching the
        # three output components in order.
        btn_submit.click(fn=classify_transaction, inputs=text_input, outputs=[cat1_out, cat2_out, score_out])

    with gr.Tab("📂 Batch CSV Upload"):
        csv_input = gr.File(label="Upload CSV file with 'transaction' column", file_types=[".csv"])
        btn_process = gr.Button("Process CSV")
        csv_output = gr.DataFrame(label="Matched Results")
        download_file = gr.File(label="Download Results CSV")

        def process_and_return(file):
            """Run ``map_csv`` on the upload and feed the table and download.

            Args:
                file: Value of the ``csv_input`` component; ``None`` when the
                    button is pressed before a file was uploaded.

            Returns:
                ``(result_df, output_path)`` for the results table and the
                download component.

            Raises:
                gr.Error: If no file was uploaded, or if ``map_csv`` reported
                    invalid input. The message is shown to the user instead
                    of silently clearing both outputs.
            """
            if file is None:
                raise gr.Error("Please upload a CSV file first.")
            df, output_path = map_csv(file)
            if isinstance(df, str):
                # map_csv signals invalid input by returning the message as
                # a string in place of the DataFrame.
                raise gr.Error(df)
            return df, output_path

        btn_process.click(fn=process_and_return, inputs=csv_input, outputs=[csv_output, download_file])
112
 
113
  # Mount Gradio inside FastAPI at /ui
114
  app = gr.mount_gradio_app(app, gradio_ui, path="/ui")
115
 
116
  # ==================================================
117
+ # 🧾 REST API Endpoints
118
  # ==================================================
119
  class TransactionsRequest(BaseModel):
120
  transactions: List[str]
 
127
def map_categories(request: TransactionsRequest):
    """Classify each transaction text in the request body.

    Args:
        request: Payload whose ``transactions`` field is a list of strings.

    Returns:
        ``{"matches": [...]}`` with one entry per input text, each holding
        ``input_text``, ``best_Cat1``, ``best_Cat2`` and ``similarity``.
    """
    # NOTE(review): the route decorator for this handler sits above the
    # visible hunk and is left as it is.
    results = []
    for text in request.transactions:
        cat1, cat2, score = classify_transaction(text)
        results.append({
            "input_text": text,
            "best_Cat1": cat1,
            "best_Cat2": cat2,
            # The previous implementation coerced the score with float();
            # keep that so a tensor or NumPy scalar cannot leak into the
            # JSON response.
            "similarity": float(score),
        })
    return {"matches": results}