yassine123Z commited on
Commit
0181d6a
·
verified ·
1 Parent(s): b5d61af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -76
app.py CHANGED
@@ -1,20 +1,19 @@
1
- # app.py
2
- from fastapi import FastAPI
3
- from pydantic import BaseModel
4
- from typing import List
5
  import pandas as pd
6
  from setfit import SetFitModel
7
  from sentence_transformers import util
8
  import torch
9
 
10
- app = FastAPI()
 
 
 
 
 
11
 
12
- # Load your trained model once at startup
13
- model = SetFitModel.from_pretrained(
14
- "HEN10/setfit-particular-transaction-solon-embeddings-labels-large-kaggle-automatisation-v1"
15
- )
16
 
17
- # Dummy reference categories (replace with your real categories or load CSV)
18
  ref_data = pd.DataFrame({
19
  "Cat1EN": [
20
  "Purchase of goods","Purchase of goods","Purchase of goods","Purchase of goods",
@@ -28,28 +27,14 @@ ref_data = pd.DataFrame({
28
  "Food & beverages","Food & beverages","Food & beverages","Food & beverages",
29
  "Food & beverages","Food & beverages","Food & beverages","Food & beverages",
30
  "Heating and air conditioning","Heating and air conditioning","Fuels","Fuels","Fuels","Fuels",
31
- "Fuels","Fuels",
32
-
33
- "Mobility (freight)","Mobility (freight)","Mobility (freight)","Mobility (freight)",
34
- "Mobility (freight)",
35
-
36
- "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
37
  "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
38
  "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
39
-
40
-
41
-
42
-
43
- "Process and fugitive emissions","Process and fugitive emissions",
44
- "Process and fugitive emissions",
45
-
46
- "Waste treatment","Waste treatment","Waste treatment",
47
  "Waste treatment","Waste treatment","Waste treatment","Waste treatment","Waste treatment",
48
- "Waste treatment","Waste treatment","Waste treatment","Waste treatment",
49
-
50
-
51
-
52
-
53
  "Use of electricity","Use of electricity","Use of electricity"
54
  ],
55
  "Cat2EN": [
@@ -64,22 +49,13 @@ ref_data = pd.DataFrame({
64
  "Condiments","Desserts","Fruits and vegetables","Fats and oils","Prepared / cooked meals",
65
  "Animal products","Cereal products","Dairy products","Heat and steam","Air conditioning and refrigeration",
66
  "Fossil fuels","Mobile fossil fuels","Organic fuels","Gaseous fossil fuels","Liquid fossil fuels",
67
- "Solid fossil fuels",
68
-
69
- "Air transport","Ship transport","Truck transport","Combined transport",
70
- "Train transport",
71
-
72
- "Air transport","Coach / Urban bus","Ship transport","Combined transport",
73
  "E-Bike","Accommodation / Events","Soft mobility","Motorcycle / Scooter","Train transport",
74
- "Public transport","Car",
75
-
76
- "Agriculture","Global warming potential","Industrial processes",
77
-
78
  "Commercial and industrial","Wastewater","Electrical equipment","Households and similar",
79
  "Metal","Organic materials","Paper and cardboard","Batteries and accumulators","Plastics",
80
- "Fugitive process emissions","Textiles","Glass",
81
-
82
- "Electricity for electric vehicles","Renewables","Standard"
83
  ],
84
  "DescriptionCat2EN": [
85
  "Goods purchase - sports","Goods purchase - buildings","Goods purchase - office items","Goods purchase - water",
@@ -93,51 +69,56 @@ ref_data = pd.DataFrame({
93
  "Food condiments","Food desserts","Food fruits & vegetables","Food fats & oils","Prepared meals",
94
  "Animal-based food","Cereal-based food","Dairy products","Heating - heat & steam","Heating - cooling/refrigeration",
95
  "Fuel - fossil","Fuel - mobile fossil","Fuel - organic","Fuel - gaseous","Fuel - liquid","Fuel - solid",
96
-
97
  "Freight transport - air","Freight transport - ship","Freight transport - truck","Freight transport - combined",
98
- "Freight transport - train",
99
-
100
- "Passenger transport - air","Passenger transport - bus","Passenger transport - ship",
101
  "Passenger transport - combined","Passenger transport - e-bike","Passenger transport - accommodation/events",
102
  "Passenger transport - soft mobility","Passenger transport - scooter/motorbike","Passenger transport - train",
103
- "Passenger transport - public","Passenger transport - car",
104
-
105
- "Emissions - agriculture","Emissions - warming potential",
106
- "Emissions - industry",
107
-
108
- "Waste - commercial/industrial","Waste - wastewater","Waste - electricals",
109
  "Waste - households","Waste - metals","Waste - organics","Waste - paper","Waste - batteries",
110
  "Waste - plastics","Waste - fugitive","Waste - textiles","Waste - glass",
111
-
112
  "Electricity - EVs","Electricity - renewables","Electricity - standard"
113
  ]
114
  })
115
 
116
-
117
-
118
  ref_data["combined"] = ref_data[["Cat1EN", "Cat2EN", "DescriptionCat2EN"]].agg(" ".join, axis=1)
119
  ref_embeddings = model.encode(ref_data["combined"].tolist())
120
 
121
- # Root endpoint so Hugging Face doesn’t show "Not Found"
122
- @app.get("/")
123
- def read_root():
124
- return {"status": "ok", "message": "Category mapping API is running. Use POST /map_categories"}
 
 
 
 
 
 
 
125
 
126
- # Define request schema
127
- class TransactionsRequest(BaseModel):
128
- transactions: List[str]
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- @app.post("/map_categories")
131
- def map_categories(request: TransactionsRequest):
132
- results = []
133
- for text in request.transactions:
134
- trans_emb = model.encode([text])[0]
135
- scores = util.pytorch_cos_sim(torch.tensor(trans_emb), torch.tensor(ref_embeddings)).flatten()
136
- best_idx = scores.argmax().item()
137
- results.append({
138
- "input_text": text,
139
- "best_Cat1": ref_data.iloc[best_idx]["Cat1EN"],
140
- "best_Cat2": ref_data.iloc[best_idx]["Cat2EN"],
141
- "similarity": float(scores[best_idx])
142
- })
143
- return {"matches": results}
 
1
+ import streamlit as st
 
 
 
2
  import pandas as pd
3
  from setfit import SetFitModel
4
  from sentence_transformers import util
5
  import torch
6
 
7
+ # Load model once
8
+ @st.cache_resource
9
+ def load_model():
10
+ return SetFitModel.from_pretrained(
11
+ "HEN10/setfit-particular-transaction-solon-embeddings-labels-large-kaggle-automatisation-v1"
12
+ )
13
 
14
+ model = load_model()
 
 
 
15
 
16
+ # Load reference categories
17
  ref_data = pd.DataFrame({
18
  "Cat1EN": [
19
  "Purchase of goods","Purchase of goods","Purchase of goods","Purchase of goods",
 
27
  "Food & beverages","Food & beverages","Food & beverages","Food & beverages",
28
  "Food & beverages","Food & beverages","Food & beverages","Food & beverages",
29
  "Heating and air conditioning","Heating and air conditioning","Fuels","Fuels","Fuels","Fuels",
30
+ "Fuels","Fuels","Mobility (freight)","Mobility (freight)","Mobility (freight)","Mobility (freight)",
31
+ "Mobility (freight)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
 
 
 
 
32
  "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
33
  "Mobility (passengers)","Mobility (passengers)","Mobility (passengers)","Mobility (passengers)",
34
+ "Mobility (passengers)","Process and fugitive emissions","Process and fugitive emissions",
35
+ "Process and fugitive emissions","Waste treatment","Waste treatment","Waste treatment",
36
+ "Waste treatment","Waste treatment","Waste treatment","Waste treatment","Waste treatment",
 
 
 
 
 
37
  "Waste treatment","Waste treatment","Waste treatment","Waste treatment","Waste treatment",
 
 
 
 
 
38
  "Use of electricity","Use of electricity","Use of electricity"
39
  ],
40
  "Cat2EN": [
 
49
  "Condiments","Desserts","Fruits and vegetables","Fats and oils","Prepared / cooked meals",
50
  "Animal products","Cereal products","Dairy products","Heat and steam","Air conditioning and refrigeration",
51
  "Fossil fuels","Mobile fossil fuels","Organic fuels","Gaseous fossil fuels","Liquid fossil fuels",
52
+ "Solid fossil fuels","Air transport","Ship transport","Truck transport","Combined transport",
53
+ "Train transport","Air transport","Coach / Urban bus","Ship transport","Combined transport",
 
 
 
 
54
  "E-Bike","Accommodation / Events","Soft mobility","Motorcycle / Scooter","Train transport",
55
+ "Public transport","Car","Agriculture","Global warming potential","Industrial processes",
 
 
 
56
  "Commercial and industrial","Wastewater","Electrical equipment","Households and similar",
57
  "Metal","Organic materials","Paper and cardboard","Batteries and accumulators","Plastics",
58
+ "Fugitive process emissions","Textiles","Glass","Electricity for electric vehicles","Renewables","Standard"
 
 
59
  ],
60
  "DescriptionCat2EN": [
61
  "Goods purchase - sports","Goods purchase - buildings","Goods purchase - office items","Goods purchase - water",
 
69
  "Food condiments","Food desserts","Food fruits & vegetables","Food fats & oils","Prepared meals",
70
  "Animal-based food","Cereal-based food","Dairy products","Heating - heat & steam","Heating - cooling/refrigeration",
71
  "Fuel - fossil","Fuel - mobile fossil","Fuel - organic","Fuel - gaseous","Fuel - liquid","Fuel - solid",
 
72
  "Freight transport - air","Freight transport - ship","Freight transport - truck","Freight transport - combined",
73
+ "Freight transport - train","Passenger transport - air","Passenger transport - bus","Passenger transport - ship",
 
 
74
  "Passenger transport - combined","Passenger transport - e-bike","Passenger transport - accommodation/events",
75
  "Passenger transport - soft mobility","Passenger transport - scooter/motorbike","Passenger transport - train",
76
+ "Passenger transport - public","Passenger transport - car","Emissions - agriculture","Emissions - warming potential",
77
+ "Emissions - industry","Waste - commercial/industrial","Waste - wastewater","Waste - electricals",
 
 
 
 
78
  "Waste - households","Waste - metals","Waste - organics","Waste - paper","Waste - batteries",
79
  "Waste - plastics","Waste - fugitive","Waste - textiles","Waste - glass",
 
80
  "Electricity - EVs","Electricity - renewables","Electricity - standard"
81
  ]
82
  })
83
 
84
+ # Precompute embeddings
 
85
  ref_data["combined"] = ref_data[["Cat1EN", "Cat2EN", "DescriptionCat2EN"]].agg(" ".join, axis=1)
86
  ref_embeddings = model.encode(ref_data["combined"].tolist())
87
 
88
+ # Streamlit UI
89
+ st.title("📊 Transaction Category Mapper")
90
+ st.write("Upload a CSV file with a column of transactions, and the app will map them to categories.")
91
+
92
+ uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
93
+
94
+ if uploaded_file:
95
+ df = pd.read_csv(uploaded_file)
96
+
97
+ # Let user choose which column to map
98
+ col_to_use = st.selectbox("Select the column containing transactions:", df.columns)
99
 
100
+ if st.button("Run Mapping"):
101
+ results = []
102
+ for text in df[col_to_use].dropna().tolist():
103
+ trans_emb = model.encode([text])[0]
104
+ scores = util.pytorch_cos_sim(torch.tensor(trans_emb), torch.tensor(ref_embeddings)).flatten()
105
+ best_idx = scores.argmax().item()
106
+ results.append({
107
+ "input_text": text,
108
+ "best_Cat1": ref_data.iloc[best_idx]["Cat1EN"],
109
+ "best_Cat2": ref_data.iloc[best_idx]["Cat2EN"],
110
+ "similarity": float(scores[best_idx])
111
+ })
112
+
113
+ results_df = pd.DataFrame(results)
114
+ st.success("✅ Mapping completed!")
115
+ st.dataframe(results_df)
116
 
117
+ # Option to download
118
+ csv = results_df.to_csv(index=False).encode("utf-8")
119
+ st.download_button(
120
+ label="📥 Download results as CSV",
121
+ data=csv,
122
+ file_name="mapped_transactions.csv",
123
+ mime="text/csv"
124
+ )