Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,8 @@ from setfit import SetFitModel
|
|
| 7 |
from sentence_transformers import util
|
| 8 |
import torch
|
| 9 |
import gradio as gr
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# ==================================================
|
| 12 |
# 🚀 Initialize FastAPI
|
|
@@ -23,15 +25,18 @@ model = SetFitModel.from_pretrained(
|
|
| 23 |
# ==================================================
|
| 24 |
# 📘 Load Reference Categories
|
| 25 |
# ==================================================
|
| 26 |
-
ref_data = pd.DataFrame({
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
| 35 |
ref_data["combined"] = ref_data[["Cat1EN", "Cat2EN", "DescriptionCat2EN"]].agg(" ".join, axis=1)
|
| 36 |
ref_embeddings = model.encode(ref_data["combined"].tolist())
|
| 37 |
|
|
@@ -48,25 +53,68 @@ def classify_transaction(text):
|
|
| 48 |
return cat1, cat2, score
|
| 49 |
|
| 50 |
# ==================================================
|
| 51 |
-
#
|
| 52 |
# ==================================================
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
# Mount Gradio inside FastAPI at /ui
|
| 66 |
app = gr.mount_gradio_app(app, gradio_ui, path="/ui")
|
| 67 |
|
| 68 |
# ==================================================
|
| 69 |
-
# 🧾 API Endpoints
|
| 70 |
# ==================================================
|
| 71 |
class TransactionsRequest(BaseModel):
|
| 72 |
transactions: List[str]
|
|
@@ -79,43 +127,11 @@ def read_root():
|
|
| 79 |
def map_categories(request: TransactionsRequest):
|
| 80 |
results = []
|
| 81 |
for text in request.transactions:
|
| 82 |
-
|
| 83 |
-
scores = util.pytorch_cos_sim(torch.tensor(trans_emb), torch.tensor(ref_embeddings)).flatten()
|
| 84 |
-
best_idx = scores.argmax().item()
|
| 85 |
results.append({
|
| 86 |
"input_text": text,
|
| 87 |
-
"best_Cat1":
|
| 88 |
-
"best_Cat2":
|
| 89 |
-
"similarity":
|
| 90 |
})
|
| 91 |
return {"matches": results}
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
feedback_data = "feedback.csv"
|
| 95 |
-
|
| 96 |
-
@app.post("/feedback/")
|
| 97 |
-
def submit_feedback(text: str, predicted_label: str, correct_label: str):
|
| 98 |
-
df = pd.DataFrame([[text, predicted_label, correct_label]],
|
| 99 |
-
columns=["text", "predicted_label", "correct_label"])
|
| 100 |
-
df.to_csv(feedback_data, mode='a', header=False, index=False)
|
| 101 |
-
return {"message": "Feedback saved successfully"}
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
@app.post("/map_categories_csv/")
|
| 105 |
-
async def map_categories_csv(file: UploadFile = File(...)):
|
| 106 |
-
df = pd.read_csv(file.file)
|
| 107 |
-
results = []
|
| 108 |
-
for text in df['transaction']:
|
| 109 |
-
trans_emb = model.encode([text])[0]
|
| 110 |
-
scores = util.pytorch_cos_sim(torch.tensor(trans_emb), torch.tensor(ref_embeddings)).flatten()
|
| 111 |
-
best_idx = scores.argmax().item()
|
| 112 |
-
results.append({
|
| 113 |
-
"input_text": text,
|
| 114 |
-
"best_Cat1": ref_data.iloc[best_idx]["Cat1EN"],
|
| 115 |
-
"best_Cat2": ref_data.iloc[best_idx]["Cat2EN"],
|
| 116 |
-
"similarity": float(scores[best_idx])
|
| 117 |
-
})
|
| 118 |
-
result_df = pd.DataFrame(results)
|
| 119 |
-
output_file = "results.csv"
|
| 120 |
-
result_df.to_csv(output_file, index=False)
|
| 121 |
-
return FileResponse(output_file, media_type='text/csv', filename="matched_results.csv")
|
|
|
|
| 7 |
from sentence_transformers import util
|
| 8 |
import torch
|
| 9 |
import gradio as gr
|
| 10 |
+
import tempfile
|
| 11 |
+
import os
|
| 12 |
|
| 13 |
# ==================================================
|
| 14 |
# 🚀 Initialize FastAPI
|
|
|
|
| 25 |
# ==================================================
|
| 26 |
# 📘 Load Reference Categories
|
| 27 |
# ==================================================
|
| 28 |
+
ref_data = pd.DataFrame({
|
| 29 |
+
"Cat1EN": ["Purchase of goods", "Mobility (passengers)", "Waste treatment", "Use of electricity"],
|
| 30 |
+
"Cat2EN": ["Office supplies", "Air transport", "Wastewater", "Renewables"],
|
| 31 |
+
"DescriptionCat2EN": [
|
| 32 |
+
"Goods purchase - office items",
|
| 33 |
+
"Passenger transport - air",
|
| 34 |
+
"Waste - wastewater",
|
| 35 |
+
"Electricity - renewables"
|
| 36 |
+
]
|
| 37 |
+
})
|
| 38 |
+
|
| 39 |
+
# Combine all category info into a single string for embeddings
|
| 40 |
ref_data["combined"] = ref_data[["Cat1EN", "Cat2EN", "DescriptionCat2EN"]].agg(" ".join, axis=1)
|
| 41 |
ref_embeddings = model.encode(ref_data["combined"].tolist())
|
| 42 |
|
|
|
|
| 53 |
return cat1, cat2, score
|
| 54 |
|
| 55 |
# ==================================================
|
| 56 |
+
# 📂 CSV Mapping Function
|
| 57 |
# ==================================================
|
| 58 |
+
def map_csv(file):
|
| 59 |
+
df = pd.read_csv(file.name)
|
| 60 |
+
if "transaction" not in df.columns:
|
| 61 |
+
return "Error: Missing column 'transaction'. Please include it in your CSV.", None
|
| 62 |
+
|
| 63 |
+
results = []
|
| 64 |
+
for text in df["transaction"]:
|
| 65 |
+
trans_emb = model.encode([text])[0]
|
| 66 |
+
scores = util.pytorch_cos_sim(torch.tensor(trans_emb), torch.tensor(ref_embeddings)).flatten()
|
| 67 |
+
best_idx = scores.argmax().item()
|
| 68 |
+
results.append({
|
| 69 |
+
"transaction": text,
|
| 70 |
+
"Predicted Category 1": ref_data.iloc[best_idx]["Cat1EN"],
|
| 71 |
+
"Predicted Category 2": ref_data.iloc[best_idx]["Cat2EN"],
|
| 72 |
+
"Similarity Score": float(scores[best_idx])
|
| 73 |
+
})
|
| 74 |
+
|
| 75 |
+
result_df = pd.DataFrame(results)
|
| 76 |
+
|
| 77 |
+
# Save to temporary file for download
|
| 78 |
+
tmp_dir = tempfile.mkdtemp()
|
| 79 |
+
output_path = os.path.join(tmp_dir, "matched_results.csv")
|
| 80 |
+
result_df.to_csv(output_path, index=False)
|
| 81 |
+
|
| 82 |
+
return result_df, output_path
|
| 83 |
+
|
| 84 |
+
# ==================================================
|
| 85 |
+
# 🖥️ Gradio Interface with Upload + Download
|
| 86 |
+
# ==================================================
|
| 87 |
+
with gr.Blocks(title="Transaction Category Classifier") as gradio_ui:
|
| 88 |
+
gr.Markdown("## 🧾 Transaction Category Classifier")
|
| 89 |
+
gr.Markdown("Enter a transaction manually or upload a CSV file to classify multiple transactions.")
|
| 90 |
+
|
| 91 |
+
with gr.Tab("🔹 Single Transaction"):
|
| 92 |
+
text_input = gr.Textbox(label="Transaction Description", placeholder="e.g., going to Barcelona using plane")
|
| 93 |
+
btn_submit = gr.Button("Submit")
|
| 94 |
+
cat1_out = gr.Label(label="Predicted Category 1")
|
| 95 |
+
cat2_out = gr.Label(label="Predicted Category 2")
|
| 96 |
+
score_out = gr.Number(label="Similarity Score")
|
| 97 |
+
btn_submit.click(fn=classify_transaction, inputs=text_input, outputs=[cat1_out, cat2_out, score_out])
|
| 98 |
+
|
| 99 |
+
with gr.Tab("📂 Batch CSV Upload"):
|
| 100 |
+
csv_input = gr.File(label="Upload CSV file with 'transaction' column", file_types=[".csv"])
|
| 101 |
+
btn_process = gr.Button("Process CSV")
|
| 102 |
+
csv_output = gr.DataFrame(label="Matched Results")
|
| 103 |
+
download_file = gr.File(label="Download Results CSV")
|
| 104 |
+
|
| 105 |
+
def process_and_return(file):
|
| 106 |
+
df, output_path = map_csv(file)
|
| 107 |
+
if isinstance(df, str):
|
| 108 |
+
return None, None
|
| 109 |
+
return df, output_path
|
| 110 |
+
|
| 111 |
+
btn_process.click(fn=process_and_return, inputs=csv_input, outputs=[csv_output, download_file])
|
| 112 |
|
| 113 |
# Mount Gradio inside FastAPI at /ui
|
| 114 |
app = gr.mount_gradio_app(app, gradio_ui, path="/ui")
|
| 115 |
|
| 116 |
# ==================================================
|
| 117 |
+
# 🧾 REST API Endpoints
|
| 118 |
# ==================================================
|
| 119 |
class TransactionsRequest(BaseModel):
|
| 120 |
transactions: List[str]
|
|
|
|
| 127 |
def map_categories(request: TransactionsRequest):
|
| 128 |
results = []
|
| 129 |
for text in request.transactions:
|
| 130 |
+
cat1, cat2, score = classify_transaction(text)
|
|
|
|
|
|
|
| 131 |
results.append({
|
| 132 |
"input_text": text,
|
| 133 |
+
"best_Cat1": cat1,
|
| 134 |
+
"best_Cat2": cat2,
|
| 135 |
+
"similarity": score
|
| 136 |
})
|
| 137 |
return {"matches": results}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|