# hchevva's picture
# Rename addressfinder.py to app.py
# 6b847bf verified
import os
import re
import pandas as pd
import gradio as gr
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report
# -----------------------------
# Helpers
# -----------------------------
def _guess_column(df: pd.DataFrame, candidates):
"""Find the first matching column name (case-insensitive) from candidates."""
cols_lower = {c.lower(): c for c in df.columns}
for cand in candidates:
if cand.lower() in cols_lower:
return cols_lower[cand.lower()]
return None
def _clean_address_series(s: pd.Series) -> pd.Series:
"""Light cleaning: ensure string, strip, collapse whitespace."""
s = s.astype(str).fillna("")
s = s.str.replace(r"\s+", " ", regex=True).str.strip()
return s
def _require(cond: bool, msg: str):
if not cond:
raise ValueError(msg)
# -----------------------------
# Core: Train / Predict
# -----------------------------
def _resolve_column(df: pd.DataFrame, override, candidates):
    """Return the user's column override if non-blank, else auto-detect one from candidates."""
    override = (override or "").strip()
    if override:
        return override
    return _guess_column(df, candidates)


def _split_train_val(X, y, test_size: float):
    """Split into train/validation, stratifying by label whenever sklearn can honor it."""
    if y.nunique() > 1:
        try:
            return train_test_split(
                X, y, test_size=test_size, random_state=42, stratify=y
            )
        except ValueError:
            # Stratification is impossible when some label has fewer than two
            # rows, or when a split side is smaller than the number of classes.
            # Fall back to a plain shuffled split instead of failing the run.
            pass
    return train_test_split(X, y, test_size=test_size, random_state=42)


def train_from_csv(
    file_obj,
    address_col_name,
    label_col_name,
    test_size,
    max_features,
    openai_key,
    state,
):
    """
    Train a DecisionTreeClassifier on bag-of-words address features from an uploaded CSV.

    Args:
        file_obj: Value of the ``gr.File`` input. Either an object exposing the
            uploaded file path as ``.name``, or the path string itself.
        address_col_name: Optional address-column override; blank/None means
            auto-detect from common names.
        label_col_name: Optional label-column override; blank/None means
            auto-detect from common names.
        test_size: Fraction of usable rows held out for validation.
        max_features: Vocabulary cap passed to ``CountVectorizer``.
        openai_key: Optional key; only stored in ``state``, never used here.
        state: Per-session dict. Updated in place with the fitted model,
            vectorizer, chosen column names and ``trained=True``.

    Returns:
        ``(summary, state)``: a human-readable training report and the same dict.

    Raises:
        ValueError: (via ``_require``) when no file was uploaded, a column cannot
            be resolved, both columns are the same, or fewer than 50 usable rows
            remain after cleaning.
    """
    _require(file_obj is not None, "Please upload a CSV file first.")

    # Store OpenAI key in-session (optional; not used unless you extend the app).
    # Do NOT print it. Keep it in state.
    if openai_key:
        state["openai_key"] = openai_key

    # Gradio 3 passes a tempfile wrapper (path in ``.name``). Gradio 4+'s default
    # gr.File(type="filepath") passes the path string itself.
    path = getattr(file_obj, "name", file_obj)
    df = pd.read_csv(path)

    # Auto-detect columns if not provided
    address_col = _resolve_column(
        df, address_col_name, ["address", "Address", "full_address", "Full_Address", "addr", "ADDR"]
    )
    label_col = _resolve_column(
        df, label_col_name, ["label", "Label", "category", "Category", "class", "Class", "y", "Y"]
    )
    _require(address_col is not None, "Could not find an address column. Provide it in 'Address column name'.")
    _require(label_col is not None, "Could not find a label column. Provide it in 'Label column name'.")
    _require(address_col in df.columns, f"Address column '{address_col}' not found in CSV.")
    _require(label_col in df.columns, f"Label column '{label_col}' not found in CSV.")
    # Selecting the same column twice yields two columns with one name, which
    # breaks the per-column cleaning below.
    _require(address_col != label_col, "Address and label columns must be different.")

    # Clean + drop bad rows. Missing cells are filled with "" BEFORE any str
    # conversion; otherwise they become the literal string "nan" and survive
    # the empty-row filter.
    df = df[[address_col, label_col]].copy()
    df[address_col] = _clean_address_series(df[address_col].fillna(""))
    df[label_col] = df[label_col].fillna("").astype(str).str.strip()
    df = df[(df[address_col] != "") & (df[label_col] != "")]
    _require(len(df) >= 50, f"Not enough usable rows after cleaning: {len(df)}. Need at least ~50.")

    # Vectorize (slider values may arrive as floats)
    vectorizer = CountVectorizer(max_features=int(max_features))
    X = vectorizer.fit_transform(df[address_col])
    y = df[label_col]

    # Split
    X_train, X_val, y_train, y_val = _split_train_val(X, y, float(test_size))

    # Train model
    model = DecisionTreeClassifier(random_state=42)
    model.fit(X_train, y_train)

    # Validate
    y_pred = model.predict(X_val)
    acc = accuracy_score(y_val, y_pred)
    report = classification_report(y_val, y_pred, zero_division=0)

    # Save to state for prediction
    state["model"] = model
    state["vectorizer"] = vectorizer
    state["address_col"] = address_col
    state["label_col"] = label_col
    state["trained"] = True

    summary = (
        f"✅ Trained DecisionTreeClassifier\n"
        f"- Rows used: {len(df)}\n"
        f"- Address col: {address_col}\n"
        f"- Label col: {label_col}\n"
        f"- Validation accuracy: {acc:.4f}\n\n"
        f"Classification report:\n{report}"
    )
    return summary, state
def predict_address(address_text, state):
    """Classify a single free-text address using the model saved in *state*.

    Returns the predicted label as a string. When the model offers
    ``predict_proba``, the string is suffixed with " (confidence ~ p)", where p
    is the probability of the predicted class. Raises ValueError (via
    ``_require``) if no model has been trained or the input is blank.
    """
    _require(state.get("trained"), "Model not trained yet. Upload CSV and click Train first.")
    has_text = address_text is not None and bool(address_text.strip())
    _require(has_text, "Enter an address to classify.")

    clf = state["model"]
    vec = state["vectorizer"]

    # Collapse whitespace runs, mirroring the cleaning applied at training time.
    normalized = re.sub(r"\s+", " ", address_text.strip())
    features = vec.transform([normalized])
    label = clf.predict(features)[0]

    if not hasattr(clf, "predict_proba"):
        return f"{label}"

    class_probs = dict(zip(clf.classes_, clf.predict_proba(features)[0]))
    confidence = float(class_probs[label])
    return f"{label} (confidence ~ {confidence:.3f})"
def clear_model(state):
    """Wipe every session entry (model, vectorizer, stored key) and mark it untrained.

    The dict is mutated in place and returned together with a status message
    for the UI.
    """
    state.clear()
    state["trained"] = False
    message = "Cleared trained model from session."
    return message, state
# -----------------------------
# UI
# -----------------------------
# Gradio UI. Components are laid out in the order they are created inside the
# Blocks / Row contexts, so the statement order below is also the on-screen order.
with gr.Blocks(title="Address Classifier Trainer") as demo:
    # Session dict handed to the callbacks wired below: train_from_csv fills it,
    # predict_address reads it, clear_model resets it.
    state = gr.State({"trained": False})
    gr.Markdown(
        """
# Address Classification Trainer (CSV → Train → Predict)
**Workflow**
1) Drag & drop your labeled CSV (15k or any size)
2) Click **Train**
3) Enter an address and click **Predict**
**Required CSV columns**
- Address column (e.g., `address`)
- Label column (e.g., `label` or `category`)
"""
    )
    # Inputs: the CSV upload and an optional API key (password-masked; it is only
    # copied into session state by train_from_csv).
    with gr.Row():
        file_in = gr.File(label="Upload labeled CSV (drag & drop)", file_types=[".csv"])
        openai_key_in = gr.Textbox(
            label="OpenAI API Key (optional; stored only in this session)",
            type="password",
            placeholder="sk-...",
        )
    # Optional column-name overrides; if left blank, train_from_csv auto-detects.
    with gr.Row():
        address_col_in = gr.Textbox(label="Address column name (optional override)", placeholder="address")
        label_col_in = gr.Textbox(label="Label column name (optional override)", placeholder="label")
    # Training knobs: held-out validation fraction and CountVectorizer vocabulary cap.
    with gr.Row():
        test_size_in = gr.Slider(0.05, 0.5, value=0.2, step=0.05, label="Validation split size")
        max_features_in = gr.Slider(1000, 50000, value=20000, step=1000, label="Max vocabulary size (CountVectorizer)")
    with gr.Row():
        train_btn = gr.Button("Train", variant="primary")
        clear_btn = gr.Button("Clear model")
    # Shared status box: shows the training report or the "cleared" message.
    train_out = gr.Textbox(label="Training output", lines=18)
    gr.Markdown("## Test a single address")
    with gr.Row():
        address_in = gr.Textbox(label="Address input", placeholder="123 Main St, Baltimore, MD 21201", lines=2)
        predict_btn = gr.Button("Predict", variant="primary")
    pred_out = gr.Textbox(label="Prediction", lines=2)
    # Wire actions
    # Training returns (report text, updated state); the input order must match
    # train_from_csv's parameter order.
    train_btn.click(
        fn=train_from_csv,
        inputs=[file_in, address_col_in, label_col_in, test_size_in, max_features_in, openai_key_in, state],
        outputs=[train_out, state],
    )
    # predict_address only reads state, so state is not listed as an output.
    predict_btn.click(
        fn=predict_address,
        inputs=[address_in, state],
        outputs=[pred_out],
    )
    # Reset the session and report it in the training output box.
    clear_btn.click(
        fn=clear_model,
        inputs=[state],
        outputs=[train_out, state],
    )

# Launch the app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()