hchevva committed on
Commit
38d06f2
·
verified ·
1 Parent(s): ff30f2d

Create addressfinder.py

Browse files
Files changed (1) hide show
  1. addressfinder.py +219 -0
addressfinder.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pandas as pd
4
+ import gradio as gr
5
+
6
+ from sklearn.model_selection import train_test_split
7
+ from sklearn.feature_extraction.text import CountVectorizer
8
+ from sklearn.tree import DecisionTreeClassifier
9
+ from sklearn.metrics import accuracy_score, classification_report
10
+
11
+
12
+ # -----------------------------
13
+ # Helpers
14
+ # -----------------------------
15
def _guess_column(df: pd.DataFrame, candidates):
    """Return the real name of the first candidate present in ``df``.

    Matching ignores case. Candidates are tried in the order given and the
    column name is returned exactly as it is spelled in ``df``. Returns
    ``None`` when no candidate matches.
    """
    # Lower-cased name -> actual name. If two columns differ only by case,
    # the later one wins.
    by_lowercase = {}
    for column in df.columns:
        by_lowercase[column.lower()] = column

    hits = (
        by_lowercase[name.lower()]
        for name in candidates
        if name.lower() in by_lowercase
    )
    return next(hits, None)
22
+
23
def _clean_address_series(s: pd.Series) -> pd.Series:
    """Normalise a column of addresses to tidy strings.

    Every value is converted to ``str``, runs of whitespace are collapsed to
    a single space, and leading/trailing whitespace is stripped. Missing
    values (``NaN``/``None``) become the empty string so callers can filter
    them out with ``!= ""``.
    """
    # Record missing positions BEFORE stringifying: ``astype(str)`` can turn
    # NaN/None into the literal text "nan"/"None", after which ``fillna``
    # has nothing left to fill.
    missing = s.isna()
    text = s.astype(str).mask(missing, "")
    return text.str.replace(r"\s+", " ", regex=True).str.strip()
28
+
29
def _require(cond: bool, msg: str):
    """Raise ``ValueError(msg)`` unless ``cond`` is truthy."""
    if cond:
        return
    raise ValueError(msg)
32
+
33
+
34
+ # -----------------------------
35
+ # Core: Train / Predict
36
+ # -----------------------------
37
def train_from_csv(
    file_obj,
    address_col_name,
    label_col_name,
    test_size,
    max_features,
    openai_key,
    state,
):
    """
    Train a bag-of-words decision-tree classifier from an uploaded CSV.

    Parameters
    - file_obj: the uploaded file; either an object exposing the path as
      ``.name`` or a plain path string.
    - address_col_name / label_col_name: optional user overrides. When blank
      (or None) the column is auto-detected from a list of common names.
    - test_size: fraction of rows held out for validation.
    - max_features: vocabulary cap passed to CountVectorizer (cast to int).
    - openai_key: optional key; only stored in ``state``, never used here.
    - state: per-session dict; updated in place with the fitted model,
      vectorizer, resolved column names and ``trained=True``.

    Returns ``(summary_text, state)``.

    Raises ValueError when no file was uploaded, a column cannot be
    resolved, or fewer than 50 usable rows remain after cleaning.
    """
    _require(file_obj is not None, "Please upload a CSV file first.")

    # Store OpenAI key in-session (optional; not used unless you extend the app)
    # Do NOT print it. Keep it in state.
    if openai_key:
        state["openai_key"] = openai_key

    # Accept both file-like upload objects (path in ``.name``) and bare paths.
    path = getattr(file_obj, "name", file_obj)
    df = pd.read_csv(path)

    # Auto-detect columns if not provided. ``or ""`` guards against a None
    # override, which would otherwise fail on ``.strip()``.
    address_override = (address_col_name or "").strip()
    label_override = (label_col_name or "").strip()
    address_col = address_override or _guess_column(
        df, ["address", "Address", "full_address", "Full_Address", "addr", "ADDR"]
    )
    label_col = label_override or _guess_column(
        df, ["label", "Label", "category", "Category", "class", "Class", "y", "Y"]
    )

    _require(address_col is not None, "Could not find an address column. Provide it in 'Address column name'.")
    _require(label_col is not None, "Could not find a label column. Provide it in 'Label column name'.")

    _require(address_col in df.columns, f"Address column '{address_col}' not found in CSV.")
    _require(label_col in df.columns, f"Label column '{label_col}' not found in CSV.")

    # Clean + drop bad rows.
    # Missing cells are dropped BEFORE converting to str: stringifying first
    # turns NaN into the text "nan", which would slip past the empty-string
    # filter and be trained on as a real address / a bogus "nan" class.
    df = df[[address_col, label_col]].dropna().copy()
    df[address_col] = _clean_address_series(df[address_col])
    df[label_col] = df[label_col].astype(str).str.strip()

    df = df[(df[address_col] != "") & (df[label_col] != "")]
    _require(len(df) >= 50, f"Not enough usable rows after cleaning: {len(df)}. Need at least ~50.")

    # Vectorize
    vectorizer = CountVectorizer(max_features=int(max_features))
    X = vectorizer.fit_transform(df[address_col])
    y = df[label_col]

    # Split. Prefer a stratified split, but scikit-learn rejects it when a
    # class is too small to appear on both sides; in that case retry without
    # stratification instead of failing the whole training run.
    split_kwargs = {"test_size": float(test_size), "random_state": 42}
    want_stratified = y.nunique() > 1
    fell_back = False
    try:
        X_train, X_val, y_train, y_val = train_test_split(
            X, y, stratify=y if want_stratified else None, **split_kwargs
        )
    except ValueError:
        if not want_stratified:
            raise
        fell_back = True
        X_train, X_val, y_train, y_val = train_test_split(X, y, **split_kwargs)

    # Train model
    model = DecisionTreeClassifier(random_state=42)
    model.fit(X_train, y_train)

    # Validate
    y_pred = model.predict(X_val)
    acc = accuracy_score(y_val, y_pred)
    report = classification_report(y_val, y_pred, zero_division=0)

    # Save to state for prediction
    state["model"] = model
    state["vectorizer"] = vectorizer
    state["address_col"] = address_col
    state["label_col"] = label_col
    state["trained"] = True

    summary = (
        f"✅ Trained DecisionTreeClassifier\n"
        f"- Rows used: {len(df)}\n"
        f"- Address col: {address_col}\n"
        f"- Label col: {label_col}\n"
        f"- Validation accuracy: {acc:.4f}\n\n"
        f"Classification report:\n{report}"
    )
    if fell_back:
        summary += "\n\nNote: some classes were too small for a stratified split; used a plain random split."

    return summary, state
119
+
120
+
121
def predict_address(address_text, state):
    """Classify one address with the model stored in ``state``.

    Whitespace in ``address_text`` is collapsed and trimmed before it is
    vectorized. Returns the predicted label as a string; when the model
    exposes ``predict_proba`` the probability of the predicted class is
    appended as `` (confidence ~ 0.xxx)``.

    Raises ValueError if no model has been trained in this session or if
    the address is missing/blank.
    """
    if not state.get("trained"):
        raise ValueError("Model not trained yet. Upload CSV and click Train first.")
    if address_text is None or address_text.strip() == "":
        raise ValueError("Enter an address to classify.")

    classifier, vec = state["model"], state["vectorizer"]

    normalized = re.sub(r"\s+", " ", address_text.strip())
    features = vec.transform([normalized])
    label = classifier.predict(features)[0]

    # Confidence is optional: only reported when the model can give probabilities.
    if not hasattr(classifier, "predict_proba"):
        return f"{label}"

    row = classifier.predict_proba(features)[0]
    position = list(classifier.classes_).index(label)
    return f"{label} (confidence ~ {float(row[position]):.3f})"
141
+
142
+
143
def clear_model(state):
    """Reset the session state in place and report that the model is gone.

    Everything stored in ``state`` is removed and ``trained`` is set back to
    ``False``. Returns ``(message, state)`` with the same dict object.
    """
    state.clear()
    state["trained"] = False
    message = "Cleared trained model from session."
    return message, state
147
+
148
+
149
+ # -----------------------------
150
+ # UI
151
+ # -----------------------------
152
with gr.Blocks(title="Address Classifier Trainer") as demo:
    # Per-session storage for the fitted model, vectorizer and settings.
    state = gr.State({"trained": False})

    gr.Markdown(
        """
# Address Classification Trainer (CSV → Train → Predict)

**Workflow**
1) Drag & drop your labeled CSV (15k or any size)
2) Click **Train**
3) Enter an address and click **Predict**

**Required CSV columns**
- Address column (e.g., `address`)
- Label column (e.g., `label` or `category`)
"""
    )

    # --- Training inputs -------------------------------------------------
    with gr.Row():
        file_in = gr.File(
            label="Upload labeled CSV (drag & drop)",
            file_types=[".csv"],
        )
        openai_key_in = gr.Textbox(
            label="OpenAI API Key (optional; stored only in this session)",
            type="password",
            placeholder="sk-...",
        )

    with gr.Row():
        address_col_in = gr.Textbox(
            label="Address column name (optional override)",
            placeholder="address",
        )
        label_col_in = gr.Textbox(
            label="Label column name (optional override)",
            placeholder="label",
        )

    with gr.Row():
        test_size_in = gr.Slider(
            minimum=0.05,
            maximum=0.5,
            value=0.2,
            step=0.05,
            label="Validation split size",
        )
        max_features_in = gr.Slider(
            minimum=1000,
            maximum=50000,
            value=20000,
            step=1000,
            label="Max vocabulary size (CountVectorizer)",
        )

    with gr.Row():
        train_btn = gr.Button("Train", variant="primary")
        clear_btn = gr.Button("Clear model")

    train_out = gr.Textbox(label="Training output", lines=18)

    # --- Single-address prediction ---------------------------------------
    gr.Markdown("## Test a single address")
    with gr.Row():
        address_in = gr.Textbox(
            label="Address input",
            placeholder="123 Main St, Baltimore, MD 21201",
            lines=2,
        )
        predict_btn = gr.Button("Predict", variant="primary")

    pred_out = gr.Textbox(label="Prediction", lines=2)

    # --- Wire actions ----------------------------------------------------
    # Input order must match the parameter order of train_from_csv.
    train_inputs = [
        file_in,
        address_col_in,
        label_col_in,
        test_size_in,
        max_features_in,
        openai_key_in,
        state,
    ]
    train_outputs = [train_out, state]

    train_btn.click(fn=train_from_csv, inputs=train_inputs, outputs=train_outputs)
    predict_btn.click(fn=predict_address, inputs=[address_in, state], outputs=[pred_out])
    # Clearing reuses the training output box to show its confirmation message.
    clear_btn.click(fn=clear_model, inputs=[state], outputs=train_outputs)

if __name__ == "__main__":
    demo.launch()