Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import pandas as pd | |
| import gradio as gr | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.feature_extraction.text import CountVectorizer | |
| from sklearn.tree import DecisionTreeClassifier | |
| from sklearn.metrics import accuracy_score, classification_report | |
| # ----------------------------- | |
| # Helpers | |
| # ----------------------------- | |
| def _guess_column(df: pd.DataFrame, candidates): | |
| """Find the first matching column name (case-insensitive) from candidates.""" | |
| cols_lower = {c.lower(): c for c in df.columns} | |
| for cand in candidates: | |
| if cand.lower() in cols_lower: | |
| return cols_lower[cand.lower()] | |
| return None | |
| def _clean_address_series(s: pd.Series) -> pd.Series: | |
| """Light cleaning: ensure string, strip, collapse whitespace.""" | |
| s = s.astype(str).fillna("") | |
| s = s.str.replace(r"\s+", " ", regex=True).str.strip() | |
| return s | |
| def _require(cond: bool, msg: str): | |
| if not cond: | |
| raise ValueError(msg) | |
| # ----------------------------- | |
| # Core: Train / Predict | |
| # ----------------------------- | |
def train_from_csv(
    file_obj,
    address_col_name,
    label_col_name,
    test_size,
    max_features,
    openai_key,
    state,
):
    """Train a bag-of-words decision tree from an uploaded, labeled CSV.

    Args:
        file_obj: Uploaded file from ``gr.File``; either an object exposing
            the path as ``.name`` or the path itself.
        address_col_name: Optional override for the address column. Blank or
            ``None`` triggers auto-detection from common names.
        label_col_name: Optional override for the label column. Blank or
            ``None`` triggers auto-detection from common names.
        test_size: Fraction of rows held out for validation.
        max_features: Maximum vocabulary size for ``CountVectorizer``.
        openai_key: Optional key; stored in ``state`` and otherwise unused
            by this function.
        state: Session dict. On success it receives ``model``,
            ``vectorizer``, ``address_col``, ``label_col`` and ``trained``.

    Returns:
        ``(summary, state)`` where ``summary`` is a human-readable training
        report including validation accuracy.

    Raises:
        ValueError: If no file was given, a column cannot be resolved, the
            address and label columns coincide, or fewer than 50 usable rows
            remain after cleaning.
    """
    _require(file_obj is not None, "Please upload a CSV file first.")

    # Store OpenAI key in-session (optional; not used unless you extend the app).
    # Do NOT print it. Keep it in state.
    if openai_key:
        state["openai_key"] = openai_key

    # Accept both an upload object carrying `.name` and a bare path string.
    path = getattr(file_obj, "name", file_obj)
    df = pd.read_csv(path)

    # Auto-detect columns if not provided. `or ""` guards against a textbox
    # value of None, on which `.strip()` would raise.
    address_col = (address_col_name or "").strip() or _guess_column(
        df, ["address", "Address", "full_address", "Full_Address", "addr", "ADDR"]
    )
    label_col = (label_col_name or "").strip() or _guess_column(
        df, ["label", "Label", "category", "Category", "class", "Class", "y", "Y"]
    )
    _require(address_col is not None, "Could not find an address column. Provide it in 'Address column name'.")
    _require(label_col is not None, "Could not find a label column. Provide it in 'Label column name'.")
    _require(address_col in df.columns, f"Address column '{address_col}' not found in CSV.")
    _require(label_col in df.columns, f"Label column '{label_col}' not found in CSV.")
    # Selecting the same column twice would yield duplicate column labels and
    # break the per-column cleaning below.
    _require(address_col != label_col, "Address and label columns must be different.")

    # Clean + drop bad rows. Rows with a missing address or label are removed
    # BEFORE string conversion; otherwise NaN could become the text "nan" and
    # be learned as a real address / class.
    df = df[[address_col, label_col]].dropna().copy()
    df[address_col] = _clean_address_series(df[address_col])
    df[label_col] = df[label_col].astype(str).str.strip()
    df = df[(df[address_col] != "") & (df[label_col] != "")]
    _require(len(df) >= 50, f"Not enough usable rows after cleaning: {len(df)}. Need at least ~50.")

    texts = df[address_col]
    y = df[label_col]

    # Split the raw text first so the vocabulary is learned from training
    # rows only. Stratification is attempted when there is more than one
    # class; scikit-learn rejects it with ValueError when a class is too
    # small, in which case we fall back to a plain random split.
    split_kwargs = {"test_size": float(test_size), "random_state": 42}
    stratify = y if y.nunique() > 1 else None
    try:
        parts = train_test_split(texts, y, stratify=stratify, **split_kwargs)
    except ValueError:
        if stratify is None:
            raise
        parts = train_test_split(texts, y, stratify=None, **split_kwargs)
    texts_train, texts_val, y_train, y_val = parts

    # Vectorize: fit on the training text, reuse the vocabulary for validation.
    vectorizer = CountVectorizer(max_features=int(max_features))
    X_train = vectorizer.fit_transform(texts_train)
    X_val = vectorizer.transform(texts_val)

    # Train model
    model = DecisionTreeClassifier(random_state=42)
    model.fit(X_train, y_train)

    # Validate
    y_pred = model.predict(X_val)
    acc = accuracy_score(y_val, y_pred)
    report = classification_report(y_val, y_pred, zero_division=0)

    # Save to state for prediction (only reached when training succeeded).
    state["model"] = model
    state["vectorizer"] = vectorizer
    state["address_col"] = address_col
    state["label_col"] = label_col
    state["trained"] = True

    summary = (
        f"✅ Trained DecisionTreeClassifier\n"
        f"- Rows used: {len(df)}\n"
        f"- Address col: {address_col}\n"
        f"- Label col: {label_col}\n"
        f"- Validation accuracy: {acc:.4f}\n\n"
        f"Classification report:\n{report}"
    )
    return summary, state
def predict_address(address_text, state):
    """Classify a single address with the model held in ``state``.

    Args:
        address_text: The address typed by the user.
        state: Session dict; must contain ``trained``, ``model`` and
            ``vectorizer`` as written by the training step.

    Returns:
        The predicted label as text, followed by a confidence suffix when
        the model exposes ``predict_proba``.

    Raises:
        ValueError: If no model has been trained or the address is blank.
    """
    _require(state.get("trained"), "Model not trained yet. Upload CSV and click Train first.")
    has_text = address_text is not None and address_text.strip() != ""
    _require(has_text, "Enter an address to classify.")

    clf = state["model"]
    vec = state["vectorizer"]

    # Same whitespace normalization as applied to the training addresses.
    normalized = re.sub(r"\s+", " ", address_text.strip())
    features = vec.transform([normalized])
    label = clf.predict(features)[0]

    # Without probability support there is nothing to append.
    if not hasattr(clf, "predict_proba"):
        return f"{label}"

    class_probs = clf.predict_proba(features)[0]
    known_classes = clf.classes_
    # Probability assigned to the class that was actually predicted.
    confidence = float(class_probs[list(known_classes).index(label)])
    suffix = f" (confidence ~ {confidence:.3f})"
    return f"{label}{suffix}"
def clear_model(state):
    """Empty the session dict in place and mark it as untrained.

    Returns:
        ``(message, state)`` — a confirmation string and the same dict
        object, now containing only ``{"trained": False}``.
    """
    state.clear()
    state["trained"] = False
    return "Cleared trained model from session.", state
| # ----------------------------- | |
| # UI | |
| # ----------------------------- | |
with gr.Blocks(title="Address Classifier Trainer") as demo:
    # Holds the trained model, vectorizer and column names between clicks.
    state = gr.State({"trained": False})

    # Blank lines separate the heading, paragraphs and lists so that each
    # renders as its own Markdown block instead of running into the previous
    # list item.
    gr.Markdown(
        """
# Address Classification Trainer (CSV → Train → Predict)

**Workflow**

1) Drag & drop your labeled CSV (15k or any size)
2) Click **Train**
3) Enter an address and click **Predict**

**Required CSV columns**

- Address column (e.g., `address`)
- Label column (e.g., `label` or `category`)
"""
    )

    # --- Training inputs ---
    with gr.Row():
        file_in = gr.File(label="Upload labeled CSV (drag & drop)", file_types=[".csv"])
        openai_key_in = gr.Textbox(
            label="OpenAI API Key (optional; stored only in this session)",
            type="password",
            placeholder="sk-...",
        )

    with gr.Row():
        address_col_in = gr.Textbox(label="Address column name (optional override)", placeholder="address")
        label_col_in = gr.Textbox(label="Label column name (optional override)", placeholder="label")

    with gr.Row():
        test_size_in = gr.Slider(0.05, 0.5, value=0.2, step=0.05, label="Validation split size")
        max_features_in = gr.Slider(1000, 50000, value=20000, step=1000, label="Max vocabulary size (CountVectorizer)")

    with gr.Row():
        train_btn = gr.Button("Train", variant="primary")
        clear_btn = gr.Button("Clear model")

    train_out = gr.Textbox(label="Training output", lines=18)

    # --- Single-address prediction ---
    gr.Markdown("## Test a single address")
    with gr.Row():
        address_in = gr.Textbox(label="Address input", placeholder="123 Main St, Baltimore, MD 21201", lines=2)
        predict_btn = gr.Button("Predict", variant="primary")

    pred_out = gr.Textbox(label="Prediction", lines=2)

    # Wire actions. The order of `inputs` must match each handler's
    # positional parameters.
    train_btn.click(
        fn=train_from_csv,
        inputs=[file_in, address_col_in, label_col_in, test_size_in, max_features_in, openai_key_in, state],
        outputs=[train_out, state],
    )
    predict_btn.click(
        fn=predict_address,
        inputs=[address_in, state],
        outputs=[pred_out],
    )
    clear_btn.click(
        fn=clear_model,
        inputs=[state],
        outputs=[train_out, state],
    )
# Launch the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()