Spaces:
Sleeping
Sleeping
File size: 7,060 Bytes
38d06f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
import os
import re
import pandas as pd
import gradio as gr
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report
# -----------------------------
# Helpers
# -----------------------------
def _guess_column(df: pd.DataFrame, candidates):
"""Find the first matching column name (case-insensitive) from candidates."""
cols_lower = {c.lower(): c for c in df.columns}
for cand in candidates:
if cand.lower() in cols_lower:
return cols_lower[cand.lower()]
return None
def _clean_address_series(s: pd.Series) -> pd.Series:
"""Light cleaning: ensure string, strip, collapse whitespace."""
s = s.astype(str).fillna("")
s = s.str.replace(r"\s+", " ", regex=True).str.strip()
return s
def _require(cond: bool, msg: str):
if not cond:
raise ValueError(msg)
# -----------------------------
# Core: Train / Predict
# -----------------------------
def train_from_csv(
    file_obj,
    address_col_name,
    label_col_name,
    test_size,
    max_features,
    openai_key,
    state,
):
    """Train a bag-of-words decision tree from an uploaded, labeled CSV.

    Args:
        file_obj: The uploaded file from the ``gr.File`` input. Either an
            object exposing the path as ``.name`` or the path itself.
        address_col_name: Optional address column override. Blank or ``None``
            means auto-detect from a list of common names.
        label_col_name: Optional label column override. Blank or ``None``
            means auto-detect from a list of common names.
        test_size: Fraction of rows held out for validation.
        max_features: Vocabulary cap passed to ``CountVectorizer``.
        openai_key: Optional key; only stored in ``state``, never printed.
        state: Session dict. Updated in place with the fitted model,
            vectorizer, resolved column names and ``trained=True``.

    Returns:
        ``(summary, state)``: a human-readable training report and the
        updated state dict.

    Raises:
        ValueError: If no file was uploaded, a column cannot be resolved, or
            fewer than 50 usable rows remain after cleaning.
    """
    _require(file_obj is not None, "Please upload a CSV file first.")

    # Store OpenAI key in-session (optional; not used unless you extend the app).
    # Do NOT print it. Keep it in state.
    if openai_key:
        state["openai_key"] = openai_key

    # Accept both a file wrapper exposing ``.name`` and a plain path string.
    path = getattr(file_obj, "name", file_obj)
    df = pd.read_csv(path)

    # Auto-detect columns if not provided. ``or ""`` guards against a None
    # override, which would otherwise crash on ``.strip()``.
    address_override = (address_col_name or "").strip()
    label_override = (label_col_name or "").strip()
    address_col = address_override or _guess_column(
        df, ["address", "Address", "full_address", "Full_Address", "addr", "ADDR"]
    )
    label_col = label_override or _guess_column(
        df, ["label", "Label", "category", "Category", "class", "Class", "y", "Y"]
    )
    _require(address_col is not None, "Could not find an address column. Provide it in 'Address column name'.")
    _require(label_col is not None, "Could not find a label column. Provide it in 'Label column name'.")
    _require(address_col in df.columns, f"Address column '{address_col}' not found in CSV.")
    _require(label_col in df.columns, f"Label column '{label_col}' not found in CSV.")

    # Clean + drop bad rows. Missing values are dropped *before* the string
    # cast; casting first turns NaN into the text "nan", which would survive
    # the empty-string filter and train a bogus "nan" address/class.
    df = df[[address_col, label_col]].dropna().copy()
    df[address_col] = _clean_address_series(df[address_col])
    df[label_col] = df[label_col].astype(str).str.strip()
    df = df[(df[address_col] != "") & (df[label_col] != "")]
    _require(len(df) >= 50, f"Not enough usable rows after cleaning: {len(df)}. Need at least ~50.")

    # Vectorize
    vectorizer = CountVectorizer(max_features=int(max_features))
    X = vectorizer.fit_transform(df[address_col])
    y = df[label_col]

    # Split. Stratification needs at least two samples in every class;
    # otherwise train_test_split raises, so fall back to a plain random split.
    class_counts = y.value_counts()
    can_stratify = len(class_counts) > 1 and class_counts.min() >= 2
    X_train, X_val, y_train, y_val = train_test_split(
        X,
        y,
        test_size=float(test_size),
        random_state=42,
        stratify=y if can_stratify else None,
    )

    # Train model
    model = DecisionTreeClassifier(random_state=42)
    model.fit(X_train, y_train)

    # Validate
    y_pred = model.predict(X_val)
    acc = accuracy_score(y_val, y_pred)
    report = classification_report(y_val, y_pred, zero_division=0)

    # Save to state for prediction
    state["model"] = model
    state["vectorizer"] = vectorizer
    state["address_col"] = address_col
    state["label_col"] = label_col
    state["trained"] = True

    summary = (
        f"✅ Trained DecisionTreeClassifier\n"
        f"- Rows used: {len(df)}\n"
        f"- Address col: {address_col}\n"
        f"- Label col: {label_col}\n"
        f"- Validation accuracy: {acc:.4f}\n\n"
        f"Classification report:\n{report}"
    )
    return summary, state
def predict_address(address_text, state):
    """Classify one address with the model stored in ``state``.

    The text is stripped and its whitespace runs collapsed before it is
    vectorized. Returns the predicted label as a string, followed by
    `` (confidence ~ p)`` when the model exposes ``predict_proba``.

    Raises:
        ValueError: If no model has been trained in this state, or the
            address is ``None`` or blank.
    """
    _require(state.get("trained"), "Model not trained yet. Upload CSV and click Train first.")
    has_text = address_text is not None and address_text.strip() != ""
    _require(has_text, "Enter an address to classify.")

    classifier, vectorizer = state["model"], state["vectorizer"]
    normalized = re.sub(r"\s+", " ", address_text.strip())
    features = vectorizer.transform([normalized])
    label = classifier.predict(features)[0]

    # Optional confidence: only reported when the model can give probabilities.
    if not hasattr(classifier, "predict_proba"):
        return f"{label}"

    row = classifier.predict_proba(features)[0]
    position = list(classifier.classes_).index(label)
    return f"{label} (confidence ~ {float(row[position]):.3f})"
def clear_model(state):
    """Empty the session state in place and mark it as untrained.

    Returns a status message together with the same (now reset) state dict.
    """
    state.clear()
    state["trained"] = False
    return "Cleared trained model from session.", state
# -----------------------------
# UI
# -----------------------------
# Build the Gradio interface. Components are created inside the ``with`` blocks,
# and the click handlers at the bottom connect the three buttons to the
# train_from_csv / predict_address / clear_model callbacks.
with gr.Blocks(title="Address Classifier Trainer") as demo:
    # State dict handed to every callback; it starts out marked as untrained.
    state = gr.State({"trained": False})

    # Intro text. The literal's lines are kept flush-left so the runtime
    # string is unchanged.
    gr.Markdown(
        """
# Address Classification Trainer (CSV → Train → Predict)
**Workflow**
1) Drag & drop your labeled CSV (15k or any size)
2) Click **Train**
3) Enter an address and click **Predict**
**Required CSV columns**
- Address column (e.g., `address`)
- Label column (e.g., `label` or `category`)
"""
    )

    # Row 1: training data upload and the optional API key (masked input).
    with gr.Row():
        file_in = gr.File(label="Upload labeled CSV (drag & drop)", file_types=[".csv"])
        openai_key_in = gr.Textbox(
            label="OpenAI API Key (optional; stored only in this session)",
            type="password",
            placeholder="sk-...",
        )

    # Row 2: optional column-name overrides; blank means auto-detect.
    with gr.Row():
        address_col_in = gr.Textbox(label="Address column name (optional override)", placeholder="address")
        label_col_in = gr.Textbox(label="Label column name (optional override)", placeholder="label")

    # Row 3: training hyper-parameters.
    with gr.Row():
        test_size_in = gr.Slider(0.05, 0.5, value=0.2, step=0.05, label="Validation split size")
        max_features_in = gr.Slider(1000, 50000, value=20000, step=1000, label="Max vocabulary size (CountVectorizer)")

    # Row 4: actions on the model.
    with gr.Row():
        train_btn = gr.Button("Train", variant="primary")
        clear_btn = gr.Button("Clear model")

    # Shows the training summary, and also the "cleared" message.
    train_out = gr.Textbox(label="Training output", lines=18)

    gr.Markdown("## Test a single address")
    with gr.Row():
        address_in = gr.Textbox(label="Address input", placeholder="123 Main St, Baltimore, MD 21201", lines=2)
        predict_btn = gr.Button("Predict", variant="primary")
    pred_out = gr.Textbox(label="Prediction", lines=2)

    # Wire actions
    # The order of ``inputs`` must match train_from_csv's positional parameters.
    train_btn.click(
        fn=train_from_csv,
        inputs=[file_in, address_col_in, label_col_in, test_size_in, max_features_in, openai_key_in, state],
        outputs=[train_out, state],
    )
    predict_btn.click(
        fn=predict_address,
        inputs=[address_in, state],
        outputs=[pred_out],
    )
    # Clearing reuses the training output box for its status message.
    clear_btn.click(
        fn=clear_model,
        inputs=[state],
        outputs=[train_out, state],
    )
# Start the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()
|