pujithapsx committed on
Commit
1bc6164
·
verified ·
1 Parent(s): 28b0196

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -69
app.py CHANGED
@@ -1,85 +1,72 @@
1
  import gradio as gr
 
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertConfig
4
 
5
- # ── Model setup ──────────────────────────────────────────────────────────────
6
  MODEL_NAME = "pujithapsx/address-crossencoder-bge-reranker-v2-m3-finetuned"
7
- THRESHOLD = 0.75
8
 
9
- print("Loading tokenizer...")
10
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
-
12
- print("Loading model weights...")
13
- # The model's config.json is missing "model_type", which breaks AutoConfig/CrossEncoder.
14
- # BertConfig.from_pretrained() reads the repo's config.json without needing model_type.
15
- config = BertConfig.from_pretrained(MODEL_NAME)
16
- hf_model = AutoModelForSequenceClassification.from_pretrained(
17
  MODEL_NAME,
18
- config=config,
19
- ignore_mismatched_sizes=True,
20
  )
21
- hf_model.eval()
22
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
- hf_model.to(device)
24
- print(f"Model loaded on {device}!")
25
-
26
 
27
- # ── Inference ─────────────────────────────────────────────────────────────────
28
- def _score_pair(text_a: str, text_b: str) -> float:
29
- """Tokenise a pair and return a 0-1 similarity score."""
30
- features = tokenizer(
31
- text_a, text_b,
32
- padding=True,
33
- truncation=True,
34
- max_length=512,
35
- return_tensors="pt",
36
- )
37
- features = {k: v.to(device) for k, v in features.items()}
38
- with torch.no_grad():
39
- logits = hf_model(**features).logits
40
- # single logit → sigmoid; two logits → softmax of positive class
41
- if logits.shape[-1] == 1:
42
- return torch.sigmoid(logits[0, 0]).item()
43
- else:
44
- return torch.softmax(logits[0], dim=-1)[1].item()
45
 
46
 
47
  def predict_similarity(input1, input2):
 
 
 
 
48
  if not input1.strip() or not input2.strip():
49
  return "—", "⚠️ Please provide both addresses", 0
50
 
51
- score = _score_pair(input1.strip(), input2.strip())
52
  similarity_pct = score * 100
53
 
54
  if score >= THRESHOLD:
55
- result = "✅ MATCH"
56
  confidence_label = "High" if score > 0.85 else "Medium"
57
  else:
58
- result = "❌ NO MATCH"
59
  confidence_label = "High" if score < 0.40 else "Medium"
60
 
61
- return f"{similarity_pct:.2f}%", f"{result} • Confidence: {confidence_label}", float(similarity_pct)
 
 
 
62
 
63
 
64
- # ── Custom CSS ────────────────────────────────────────────────────────────────
65
  custom_css = """
66
  @import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500;600&display=swap');
67
 
68
  :root {
69
- --bg: #0b0f1a;
70
- --surface: #111827;
71
- --border: #1f2d45;
72
- --accent: #38bdf8;
73
- --accent2: #818cf8;
74
- --text: #e2e8f0;
75
- --muted: #64748b;
76
- --radius: 12px;
 
 
 
77
  }
 
78
  body, .gradio-container {
79
  background: var(--bg) !important;
80
  font-family: 'DM Sans', sans-serif !important;
81
  color: var(--text) !important;
82
  }
 
 
83
  #header-box {
84
  background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%);
85
  border: 1px solid var(--border);
@@ -103,6 +90,7 @@ body, .gradio-container {
103
  font-weight: 700 !important;
104
  color: var(--accent) !important;
105
  margin: 0 0 8px !important;
 
106
  }
107
  #header-box p {
108
  color: var(--muted) !important;
@@ -122,7 +110,10 @@ body, .gradio-container {
122
  margin-right: 8px;
123
  margin-top: 12px;
124
  }
125
- .input-card textarea, .input-card input {
 
 
 
126
  background: var(--surface) !important;
127
  border: 1px solid var(--border) !important;
128
  border-radius: var(--radius) !important;
@@ -132,7 +123,8 @@ body, .gradio-container {
132
  padding: 14px 16px !important;
133
  transition: border-color 0.2s;
134
  }
135
- .input-card textarea:focus, .input-card input:focus {
 
136
  border-color: var(--accent) !important;
137
  box-shadow: 0 0 0 3px rgba(56,189,248,0.1) !important;
138
  }
@@ -143,6 +135,8 @@ label span {
143
  letter-spacing: 0.5px;
144
  text-transform: uppercase;
145
  }
 
 
146
  #run-btn {
147
  background: linear-gradient(135deg, var(--accent) 0%, var(--accent2) 100%) !important;
148
  border: none !important;
@@ -151,14 +145,19 @@ label span {
151
  font-family: 'Space Mono', monospace !important;
152
  font-size: 0.9rem !important;
153
  font-weight: 700 !important;
 
154
  padding: 14px 32px !important;
 
 
155
  width: 100%;
156
  margin-top: 8px;
157
- transition: opacity 0.2s, transform 0.15s;
158
  }
159
- #run-btn:hover { opacity: 0.9; transform: translateY(-1px); }
160
  #run-btn:active { transform: translateY(0); }
161
- .output-card textarea, .output-card input {
 
 
 
162
  background: #0d1424 !important;
163
  border: 1px solid var(--border) !important;
164
  border-radius: var(--radius) !important;
@@ -168,7 +167,13 @@ label span {
168
  font-weight: 700 !important;
169
  text-align: center;
170
  }
171
- .score-slider input[type=range] { accent-color: var(--accent); }
 
 
 
 
 
 
172
  .gr-samples-table {
173
  background: var(--surface) !important;
174
  border: 1px solid var(--border) !important;
@@ -179,10 +184,18 @@ label span {
179
  font-size: 0.72rem !important;
180
  color: var(--muted) !important;
181
  text-transform: uppercase;
 
182
  background: #0d1424 !important;
183
  }
184
- .gr-samples-table td { color: var(--text) !important; font-size: 0.88rem !important; }
185
- .gr-samples-table tr:hover td { background: rgba(56,189,248,0.04) !important; }
 
 
 
 
 
 
 
186
  #footer-info {
187
  background: var(--surface);
188
  border: 1px solid var(--border);
@@ -200,9 +213,10 @@ label span {
200
  #footer-info span { color: var(--accent) !important; }
201
  """
202
 
203
- # ── Gradio UI ────────────────────────────────────────────────────────────────
204
  with gr.Blocks(css=custom_css, title="Address Entity Matcher") as demo:
205
 
 
206
  gr.HTML("""
207
  <div id="header-box">
208
  <h1>📍 Address Entity Matcher</h1>
@@ -210,11 +224,13 @@ with gr.Blocks(css=custom_css, title="Address Entity Matcher") as demo:
210
  Enter two addresses to determine whether they refer to the same location.<br>
211
  Powered by a fine-tuned <strong>BGE-Reranker-v2-m3</strong> cross-encoder model.
212
  </p>
 
213
  <span class="badge">BGE-Reranker-v2-m3</span>
214
  <span class="badge">Threshold: 0.75</span>
215
  </div>
216
  """)
217
 
 
218
  with gr.Row(equal_height=True):
219
  with gr.Column(elem_classes="input-card"):
220
  input1 = gr.Textbox(
@@ -231,39 +247,55 @@ with gr.Blocks(css=custom_css, title="Address Entity Matcher") as demo:
231
 
232
  btn = gr.Button("🔎 Check Match", elem_id="run-btn", variant="primary")
233
 
 
234
  with gr.Row(equal_height=True):
235
  with gr.Column(elem_classes="output-card"):
236
- similarity_output = gr.Textbox(label="Similarity Score", interactive=False, placeholder="—")
 
 
 
 
237
  with gr.Column(elem_classes="output-card"):
238
- result_output = gr.Textbox(label="Match Result", interactive=False, placeholder="—")
 
 
 
 
239
 
240
  score_bar = gr.Slider(
241
- minimum=0, maximum=100, value=0, step=0.01,
 
 
 
242
  label="Score Visualisation (threshold line: 75%)",
243
- interactive=False, elem_classes="score-slider",
 
244
  )
245
 
 
246
  gr.Examples(
247
  examples=[
248
- ["Flat 12-B Sector 5 Noida", "Flat 23-B Sector 5 Noida"],
249
- ["Phase 4 Whitefield Bangalore", "Whitefield Phase V Bangalore"],
250
- ["Thirteen I 7th Avenue Adyar Chennai", "13 Seventh Avenue Adyar Chennai"],
251
- ["Twenty Nine A Second Cross Koramangala Bengaluru", "47 Forty Seven B Third Street Indiranagar Bengaluru"],
252
- ["Plot 8 Banjara Hills Hyderabad", "Plot 8 Banjara Hills Hyderabad"],
253
- ["House No 4 Lane 2 DLF Phase 1 Gurugram", "H.No 4/2 DLF Phase One Gurgaon"],
254
  ],
255
  inputs=[input1, input2],
256
  label="📋 Try these examples",
257
  )
258
 
 
259
  gr.HTML(f"""
260
  <div id="footer-info">
261
- <p>🤖 <strong>Model:</strong> <span>{MODEL_NAME}</span></p>
262
  <p>📏 <strong>Threshold:</strong> <span>{THRESHOLD}</span> — Score ≥ {THRESHOLD} → MATCH &nbsp;|&nbsp; Score &lt; {THRESHOLD} → NO MATCH</p>
263
  <p>🏷️ <strong>Confidence:</strong> High (score &gt; 0.85 or &lt; 0.40) &nbsp;|&nbsp; Medium (otherwise)</p>
264
  </div>
265
  """)
266
 
 
267
  btn.click(
268
  fn=predict_similarity,
269
  inputs=[input1, input2],
 
1
  import gradio as gr
2
+ from sentence_transformers import CrossEncoder
3
  import torch
 
4
 
5
# Fine-tuned cross-encoder checkpoint hosted on the Hugging Face Hub.
MODEL_NAME = "pujithapsx/address-crossencoder-bge-reranker-v2-m3-finetuned"

print("Loading model...")
# CrossEncoder is used directly because it handles this checkpoint's format.
# NOTE(review): trust_remote_code=True runs Python code shipped inside the
# model repository; keep it only while the checkpoint really requires it.
model = CrossEncoder(
    MODEL_NAME,
    trust_remote_code=True,
    # Pick the GPU when the host exposes one instead of pinning the app to CPU.
    device="cuda" if torch.cuda.is_available() else "cpu",
)
print("Model loaded successfully!")

# Static decision threshold: scores at or above this value count as a match.
THRESHOLD = 0.75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  def predict_similarity(input1, input2):
22
+ """
23
+ Predict similarity between two addresses using static threshold.
24
+ Returns: Similarity %, Match/No Match result, and confidence bar value.
25
+ """
26
  if not input1.strip() or not input2.strip():
27
  return "—", "⚠️ Please provide both addresses", 0
28
 
29
+ score = model.predict([[input1, input2]])[0]
30
  similarity_pct = score * 100
31
 
32
  if score >= THRESHOLD:
33
+ result = "✅ MATCH"
34
  confidence_label = "High" if score > 0.85 else "Medium"
35
  else:
36
+ result = "❌ NO MATCH"
37
  confidence_label = "High" if score < 0.40 else "Medium"
38
 
39
+ similarity_str = f"{similarity_pct:.2f}%"
40
+ result_str = f"{result} • Confidence: {confidence_label}"
41
+
42
+ return similarity_str, result_str, float(similarity_pct)
43
 
44
 
45
+ # ── Custom CSS ──────────────────────────────────────────────────────────────
46
  custom_css = """
47
  @import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500;600&display=swap');
48
 
49
  :root {
50
+ --bg: #0b0f1a;
51
+ --surface: #111827;
52
+ --border: #1f2d45;
53
+ --accent: #38bdf8;
54
+ --accent2: #818cf8;
55
+ --green: #34d399;
56
+ --red: #f87171;
57
+ --yellow: #fbbf24;
58
+ --text: #e2e8f0;
59
+ --muted: #64748b;
60
+ --radius: 12px;
61
  }
62
+
63
  body, .gradio-container {
64
  background: var(--bg) !important;
65
  font-family: 'DM Sans', sans-serif !important;
66
  color: var(--text) !important;
67
  }
68
+
69
+ /* ── header ── */
70
  #header-box {
71
  background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%);
72
  border: 1px solid var(--border);
 
90
  font-weight: 700 !important;
91
  color: var(--accent) !important;
92
  margin: 0 0 8px !important;
93
+ letter-spacing: -0.5px;
94
  }
95
  #header-box p {
96
  color: var(--muted) !important;
 
110
  margin-right: 8px;
111
  margin-top: 12px;
112
  }
113
+
114
+ /* ── input cards ── */
115
+ .input-card textarea,
116
+ .input-card input {
117
  background: var(--surface) !important;
118
  border: 1px solid var(--border) !important;
119
  border-radius: var(--radius) !important;
 
123
  padding: 14px 16px !important;
124
  transition: border-color 0.2s;
125
  }
126
+ .input-card textarea:focus,
127
+ .input-card input:focus {
128
  border-color: var(--accent) !important;
129
  box-shadow: 0 0 0 3px rgba(56,189,248,0.1) !important;
130
  }
 
135
  letter-spacing: 0.5px;
136
  text-transform: uppercase;
137
  }
138
+
139
+ /* ── button ── */
140
  #run-btn {
141
  background: linear-gradient(135deg, var(--accent) 0%, var(--accent2) 100%) !important;
142
  border: none !important;
 
145
  font-family: 'Space Mono', monospace !important;
146
  font-size: 0.9rem !important;
147
  font-weight: 700 !important;
148
+ letter-spacing: 0.5px;
149
  padding: 14px 32px !important;
150
+ cursor: pointer;
151
+ transition: opacity 0.2s, transform 0.15s;
152
  width: 100%;
153
  margin-top: 8px;
 
154
  }
155
+ #run-btn:hover { opacity: 0.9; transform: translateY(-1px); }
156
  #run-btn:active { transform: translateY(0); }
157
+
158
+ /* ── output cards ── */
159
+ .output-card textarea,
160
+ .output-card input {
161
  background: #0d1424 !important;
162
  border: 1px solid var(--border) !important;
163
  border-radius: var(--radius) !important;
 
167
  font-weight: 700 !important;
168
  text-align: center;
169
  }
170
+
171
+ /* ── slider (score bar) ── */
172
+ .score-slider input[type=range] {
173
+ accent-color: var(--accent);
174
+ }
175
+
176
+ /* ── examples table ── */
177
  .gr-samples-table {
178
  background: var(--surface) !important;
179
  border: 1px solid var(--border) !important;
 
184
  font-size: 0.72rem !important;
185
  color: var(--muted) !important;
186
  text-transform: uppercase;
187
+ letter-spacing: 0.5px;
188
  background: #0d1424 !important;
189
  }
190
+ .gr-samples-table td {
191
+ color: var(--text) !important;
192
+ font-size: 0.88rem !important;
193
+ }
194
+ .gr-samples-table tr:hover td {
195
+ background: rgba(56,189,248,0.04) !important;
196
+ }
197
+
198
+ /* ── info footer ── */
199
  #footer-info {
200
  background: var(--surface);
201
  border: 1px solid var(--border);
 
213
  #footer-info span { color: var(--accent) !important; }
214
  """
215
 
216
+ # ── Gradio UI ────────────────────────────────────────────────────────────────
217
  with gr.Blocks(css=custom_css, title="Address Entity Matcher") as demo:
218
 
219
+ # Header
220
  gr.HTML("""
221
  <div id="header-box">
222
  <h1>📍 Address Entity Matcher</h1>
 
224
  Enter two addresses to determine whether they refer to the same location.<br>
225
  Powered by a fine-tuned <strong>BGE-Reranker-v2-m3</strong> cross-encoder model.
226
  </p>
227
+ <span class="badge">CrossEncoder</span>
228
  <span class="badge">BGE-Reranker-v2-m3</span>
229
  <span class="badge">Threshold: 0.75</span>
230
  </div>
231
  """)
232
 
233
+ # Inputs
234
  with gr.Row(equal_height=True):
235
  with gr.Column(elem_classes="input-card"):
236
  input1 = gr.Textbox(
 
247
 
248
  btn = gr.Button("🔎 Check Match", elem_id="run-btn", variant="primary")
249
 
250
+ # Outputs
251
  with gr.Row(equal_height=True):
252
  with gr.Column(elem_classes="output-card"):
253
+ similarity_output = gr.Textbox(
254
+ label="Similarity Score",
255
+ interactive=False,
256
+ placeholder="—",
257
+ )
258
  with gr.Column(elem_classes="output-card"):
259
+ result_output = gr.Textbox(
260
+ label="Match Result",
261
+ interactive=False,
262
+ placeholder="—",
263
+ )
264
 
265
  score_bar = gr.Slider(
266
+ minimum=0,
267
+ maximum=100,
268
+ value=0,
269
+ step=0.01,
270
  label="Score Visualisation (threshold line: 75%)",
271
+ interactive=False,
272
+ elem_classes="score-slider",
273
  )
274
 
275
+ # Examples
276
  gr.Examples(
277
  examples=[
278
+ ["Flat 12-B Sector 5 Noida", "Flat 23-B Sector 5 Noida"],
279
+ ["Phase 4 Whitefield Bangalore", "Whitefield Phase V Bangalore"],
280
+ ["Thirteen I 7th Avenue Adyar Chennai", "13 Seventh Avenue Adyar Chennai"],
281
+ ["Twenty Nine A Second Cross Koramangala Bengaluru", "47 Forty Seven B Third Street Indiranagar Bengaluru"],
282
+ ["Plot 8 Banjara Hills Hyderabad", "Plot 8 Banjara Hills Hyderabad"],
283
+ ["House No 4 Lane 2 DLF Phase 1 Gurugram", "H.No 4/2 DLF Phase One Gurgaon"],
284
  ],
285
  inputs=[input1, input2],
286
  label="📋 Try these examples",
287
  )
288
 
289
+ # Footer info
290
  gr.HTML(f"""
291
  <div id="footer-info">
292
+ <p>🤖 <strong>Model:</strong> <span>pujithapsx/address-crossencoder-bge-reranker-v2-m3-finetuned</span></p>
293
  <p>📏 <strong>Threshold:</strong> <span>{THRESHOLD}</span> — Score ≥ {THRESHOLD} → MATCH &nbsp;|&nbsp; Score &lt; {THRESHOLD} → NO MATCH</p>
294
  <p>🏷️ <strong>Confidence:</strong> High (score &gt; 0.85 or &lt; 0.40) &nbsp;|&nbsp; Medium (otherwise)</p>
295
  </div>
296
  """)
297
 
298
+ # Wiring
299
  btn.click(
300
  fn=predict_similarity,
301
  inputs=[input1, input2],