import math
import os
import re

import gradio as gr
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


# -----------------------------
# Helpers
# -----------------------------

# Header names tried (case-insensitively, in order) when the user leaves the
# corresponding "column name" override box empty.
_ADDRESS_CANDIDATES = ("address", "full_address", "addr")
_LABEL_CANDIDATES = ("label", "category", "class", "y")

_WHITESPACE_RE = re.compile(r"\s+")


def _guess_column(df: pd.DataFrame, candidates):
    """Return the first DataFrame column matching one of *candidates*.

    Matching is case-insensitive and ignores whitespace around CSV headers.
    Candidates are tried in order. Returns ``None`` when nothing matches.
    """
    cols_lower = {}
    for col in df.columns:
        # setdefault keeps the FIRST column when two headers differ only by case.
        cols_lower.setdefault(str(col).strip().lower(), col)
    for cand in candidates:
        match = cols_lower.get(str(cand).strip().lower())
        if match is not None:
            return match
    return None


def _normalize_address(text: str) -> str:
    """Collapse runs of whitespace to a single space and strip both ends."""
    return _WHITESPACE_RE.sub(" ", text).strip()


def _clean_address_series(s: pd.Series) -> pd.Series:
    """Light cleaning: missing -> "", ensure string, collapse whitespace, strip.

    ``fillna`` must run BEFORE ``astype(str)``. Converting first turns NaN
    into the literal string "nan", which would then survive as a fake address.
    """
    s = s.fillna("").astype(str)
    return s.str.replace(r"\s+", " ", regex=True).str.strip()


def _require(cond: bool, msg: str):
    """Raise ``ValueError(msg)`` when *cond* is falsy (user-facing validation)."""
    if not cond:
        raise ValueError(msg)


def _resolve_input_path(file_obj) -> str:
    """Return the filesystem path of a ``gr.File`` upload.

    Gradio 4+ passes a plain path string by default (``type="filepath"``).
    Gradio 3 passed a tempfile wrapper exposing ``.name``. Accept both.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    return file_obj.name


def _resolve_column(df: pd.DataFrame, override, candidates, kind: str, missing_msg: str):
    """Pick the DataFrame column to use for *kind* ("Address" or "Label").

    A non-blank *override* must exist in the CSV; it is matched
    case-insensitively, like auto-detection. A blank or None override falls
    back to the first matching name in *candidates*.
    """
    override = (override or "").strip()
    if override:
        col = _guess_column(df, [override])
        _require(col is not None, f"{kind} column '{override}' not found in CSV.")
        return col
    col = _guess_column(df, candidates)
    _require(col is not None, missing_msg)
    return col


def _stratify_target(y: pd.Series, test_size: float):
    """Return *y* when a stratified split is feasible, else ``None``.

    ``train_test_split(stratify=...)`` raises when any class has fewer than
    two rows, or when the test/train partition has fewer rows than there are
    classes. Rare labels are common in real address data, so fall back to a
    plain random split instead of failing the whole training run.
    """
    counts = y.value_counts()
    n_classes = len(counts)
    # Same sizing rule scikit-learn uses for a float test_size: ceil for test.
    n_test = math.ceil(test_size * len(y))
    n_train = len(y) - n_test
    feasible = (
        n_classes > 1
        and counts.min() >= 2
        and n_test >= n_classes
        and n_train >= n_classes
    )
    return y if feasible else None


# -----------------------------
# Core: Train / Predict
# -----------------------------
def train_from_csv(
    file_obj,
    address_col_name,
    label_col_name,
    test_size,
    max_features,
    openai_key,
    state,
):
    """Train a bag-of-words DecisionTreeClassifier from an uploaded CSV.

    Args:
        file_obj: Value of the ``gr.File`` component. This is a path string in
            Gradio 4+, or an object with ``.name`` in Gradio 3.
        address_col_name: Optional override for the address column; blank
            means auto-detect.
        label_col_name: Optional override for the label column; blank means
            auto-detect.
        test_size: Fraction of usable rows held out for validation.
        max_features: Maximum CountVectorizer vocabulary size.
        openai_key: Optional key, kept only in session state (unused here).
        state: Per-session dict from ``gr.State``; updated in place with the
            trained model and vectorizer.

    Returns:
        ``(summary_text, state)`` for the training-output box and the state.

    Raises:
        ValueError: if no file was uploaded, the CSV cannot be parsed, the
            columns cannot be found, or too little usable data remains.
    """
    _require(file_obj is not None, "Please upload a CSV file first.")

    # Store OpenAI key in-session (optional; not used unless you extend the app).
    # Do NOT print it. Keep it in state.
    if openai_key:
        state["openai_key"] = openai_key

    path = _resolve_input_path(file_obj)
    try:
        # dtype=str keeps cell text verbatim. Otherwise integer labels with
        # gaps would be parsed as floats and become "1.0"-style class names.
        df = pd.read_csv(path, dtype=str)
    except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as err:
        raise ValueError(f"Could not read the uploaded CSV: {err}") from err

    address_col = _resolve_column(
        df,
        address_col_name,
        _ADDRESS_CANDIDATES,
        "Address",
        "Could not find an address column. Provide it in 'Address column name'.",
    )
    label_col = _resolve_column(
        df,
        label_col_name,
        _LABEL_CANDIDATES,
        "Label",
        "Could not find a label column. Provide it in 'Label column name'.",
    )
    # Selecting the same column twice would yield duplicate DataFrame columns.
    _require(address_col != label_col, "Address and label columns must be different.")

    # Clean + drop rows with an empty address or label.
    df = df[[address_col, label_col]].copy()
    df[address_col] = _clean_address_series(df[address_col])
    df[label_col] = df[label_col].fillna("").astype(str).str.strip()
    df = df[(df[address_col] != "") & (df[label_col] != "")]
    _require(len(df) >= 50, f"Not enough usable rows after cleaning: {len(df)}. Need at least ~50.")

    y = df[label_col]
    n_classes = y.nunique()
    _require(n_classes >= 2, "Need at least 2 distinct labels to train a classifier.")

    # Split raw texts first so the vectorizer's vocabulary is learned from the
    # training rows only. The split indices match splitting the vectorized matrix.
    test_frac = float(test_size)
    stratify = _stratify_target(y, test_frac)
    train_texts, val_texts, y_train, y_val = train_test_split(
        df[address_col], y, test_size=test_frac, random_state=42, stratify=stratify
    )

    # Vectorize
    vectorizer = CountVectorizer(max_features=int(max_features))
    X_train = vectorizer.fit_transform(train_texts)
    X_val = vectorizer.transform(val_texts)

    # Train model
    model = DecisionTreeClassifier(random_state=42)
    model.fit(X_train, y_train)

    # Validate
    y_pred = model.predict(X_val)
    acc = accuracy_score(y_val, y_pred)
    report = classification_report(y_val, y_pred, zero_division=0)

    # Save to state for prediction
    state["model"] = model
    state["vectorizer"] = vectorizer
    state["address_col"] = address_col
    state["label_col"] = label_col
    state["trained"] = True

    stratified_note = "yes" if stratify is not None else "no (some labels too rare to stratify)"
    summary = (
        f"✅ Trained DecisionTreeClassifier\n"
        f"- Rows used: {len(df)}\n"
        f"- Address col: {address_col}\n"
        f"- Label col: {label_col}\n"
        f"- Distinct labels: {n_classes}\n"
        f"- Stratified split: {stratified_note}\n"
        f"- Validation accuracy: {acc:.4f}\n\n"
        f"Classification report:\n{report}"
    )
    return summary, state


def predict_address(address_text, state):
    """Classify one address using the model stored in session *state*.

    Returns the predicted label. When the model exposes ``predict_proba``,
    a confidence estimate is appended. A warning is also appended when none
    of the address's tokens are in the trained vocabulary, because that
    prediction is essentially arbitrary.

    Raises:
        ValueError: if no model has been trained or the input is blank.
    """
    _require((state or {}).get("trained"), "Model not trained yet. Upload CSV and click Train first.")
    _require(address_text is not None and address_text.strip() != "", "Enter an address to classify.")

    model = state["model"]
    vectorizer = state["vectorizer"]

    # Same whitespace normalization that training applied to the address column.
    X = vectorizer.transform([_normalize_address(address_text)])
    pred = model.predict(X)[0]

    # Optional confidence if the model supports predict_proba
    conf_str = ""
    if hasattr(model, "predict_proba"):
        probs = model.predict_proba(X)[0]
        p = float(probs[list(model.classes_).index(pred)])
        conf_str = f" (confidence ~ {p:.3f})"

    # An all-zero feature row means no known words: flag it instead of
    # presenting a confident-looking guess.
    if X.nnz == 0:
        conf_str += " (warning: no words from this address were seen in training)"

    return f"{pred}{conf_str}"


def clear_model(state):
    """Drop everything in the session state and mark it as untrained."""
    state.clear()
    state.update({"trained": False})
    return "Cleared trained model from session.", state


# -----------------------------
# UI
# -----------------------------
with gr.Blocks(title="Address Classifier Trainer") as demo:
    state = gr.State({"trained": False})

    gr.Markdown(
        """
        # Address Classification Trainer (CSV → Train → Predict)

        **Workflow**
        1) Drag & drop your labeled CSV (15k or any size)
        2) Click **Train**
        3) Enter an address and click **Predict**

        **Required CSV columns**
        - Address column (e.g., `address`)
        - Label column (e.g., `label` or `category`)
        """
    )

    with gr.Row():
        file_in = gr.File(label="Upload labeled CSV (drag & drop)", file_types=[".csv"])
        openai_key_in = gr.Textbox(
            label="OpenAI API Key (optional; stored only in this session)",
            type="password",
            placeholder="sk-...",
        )

    with gr.Row():
        address_col_in = gr.Textbox(label="Address column name (optional override)", placeholder="address")
        label_col_in = gr.Textbox(label="Label column name (optional override)", placeholder="label")

    with gr.Row():
        test_size_in = gr.Slider(0.05, 0.5, value=0.2, step=0.05, label="Validation split size")
        max_features_in = gr.Slider(
            1000, 50000, value=20000, step=1000, label="Max vocabulary size (CountVectorizer)"
        )

    with gr.Row():
        train_btn = gr.Button("Train", variant="primary")
        clear_btn = gr.Button("Clear model")

    train_out = gr.Textbox(label="Training output", lines=18)

    gr.Markdown("## Test a single address")
    with gr.Row():
        address_in = gr.Textbox(label="Address input", placeholder="123 Main St, Baltimore, MD 21201", lines=2)
        predict_btn = gr.Button("Predict", variant="primary")
    pred_out = gr.Textbox(label="Prediction", lines=2)

    # Wire actions
    train_btn.click(
        fn=train_from_csv,
        inputs=[file_in, address_col_in, label_col_in, test_size_in, max_features_in, openai_key_in, state],
        outputs=[train_out, state],
    )

    predict_btn.click(
        fn=predict_address,
        inputs=[address_in, state],
        outputs=[pred_out],
    )

    clear_btn.click(
        fn=clear_model,
        inputs=[state],
        outputs=[train_out, state],
    )


if __name__ == "__main__":
    # show_error=True lets the ValueError messages raised by _require reach the
    # browser instead of a generic "Error" popup.
    demo.launch(show_error=True)