pujithapsx committed on
Commit
28b0196
·
verified ·
1 Parent(s): c5a3a56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -107
app.py CHANGED
@@ -1,74 +1,85 @@
1
  import gradio as gr
2
- from sentence_transformers import CrossEncoder
 
3
 
4
- # Load fine-tuned model from Hub
5
  MODEL_NAME = "pujithapsx/address-crossencoder-bge-reranker-v2-m3-finetuned"
 
6
 
7
- print("Loading model...")
8
- #model = CrossEncoder(MODEL_NAME)
9
- # model = CrossEncoder(
10
- # MODEL_NAME,
11
- # trust_remote_code=True, # allow custom model code
12
- # automodel_args={"ignore_mismatched_sizes": True}, # skip size mismatch errors
13
- # config_args={"model_type": "bert"}, # inject missing model_type
14
- # )
15
- config = BertConfig.from_pretrained(MODEL_NAME) # BertConfig doesn't need model_type
16
- hf_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, config=config)
17
- print("Model loaded successfully!")
18
 
19
- # STATIC THRESHOLD
20
- THRESHOLD = 0.75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  def predict_similarity(input1, input2):
24
- """
25
- Predict similarity between two addresses using static threshold.
26
- Returns: Similarity %, Match/No Match result, and confidence bar value.
27
- """
28
  if not input1.strip() or not input2.strip():
29
  return "—", "⚠️ Please provide both addresses", 0
30
 
31
- score = model.predict([[input1, input2]])[0]
32
  similarity_pct = score * 100
33
 
34
  if score >= THRESHOLD:
35
- result = "✅ MATCH"
36
  confidence_label = "High" if score > 0.85 else "Medium"
37
  else:
38
- result = "❌ NO MATCH"
39
  confidence_label = "High" if score < 0.40 else "Medium"
40
 
41
- similarity_str = f"{similarity_pct:.2f}%"
42
- result_str = f"{result} • Confidence: {confidence_label}"
43
-
44
- return similarity_str, result_str, float(similarity_pct)
45
 
46
 
47
- # ── Custom CSS ──────────────────────────────────────────────────────────────
48
  custom_css = """
49
  @import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500;600&display=swap');
50
 
51
  :root {
52
- --bg: #0b0f1a;
53
- --surface: #111827;
54
- --border: #1f2d45;
55
- --accent: #38bdf8;
56
- --accent2: #818cf8;
57
- --green: #34d399;
58
- --red: #f87171;
59
- --yellow: #fbbf24;
60
- --text: #e2e8f0;
61
- --muted: #64748b;
62
- --radius: 12px;
63
  }
64
-
65
  body, .gradio-container {
66
  background: var(--bg) !important;
67
  font-family: 'DM Sans', sans-serif !important;
68
  color: var(--text) !important;
69
  }
70
-
71
- /* ── header ── */
72
  #header-box {
73
  background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%);
74
  border: 1px solid var(--border);
@@ -92,7 +103,6 @@ body, .gradio-container {
92
  font-weight: 700 !important;
93
  color: var(--accent) !important;
94
  margin: 0 0 8px !important;
95
- letter-spacing: -0.5px;
96
  }
97
  #header-box p {
98
  color: var(--muted) !important;
@@ -112,10 +122,7 @@ body, .gradio-container {
112
  margin-right: 8px;
113
  margin-top: 12px;
114
  }
115
-
116
- /* ── input cards ── */
117
- .input-card textarea,
118
- .input-card input {
119
  background: var(--surface) !important;
120
  border: 1px solid var(--border) !important;
121
  border-radius: var(--radius) !important;
@@ -125,8 +132,7 @@ body, .gradio-container {
125
  padding: 14px 16px !important;
126
  transition: border-color 0.2s;
127
  }
128
- .input-card textarea:focus,
129
- .input-card input:focus {
130
  border-color: var(--accent) !important;
131
  box-shadow: 0 0 0 3px rgba(56,189,248,0.1) !important;
132
  }
@@ -137,8 +143,6 @@ label span {
137
  letter-spacing: 0.5px;
138
  text-transform: uppercase;
139
  }
140
-
141
- /* ── button ── */
142
  #run-btn {
143
  background: linear-gradient(135deg, var(--accent) 0%, var(--accent2) 100%) !important;
144
  border: none !important;
@@ -147,19 +151,14 @@ label span {
147
  font-family: 'Space Mono', monospace !important;
148
  font-size: 0.9rem !important;
149
  font-weight: 700 !important;
150
- letter-spacing: 0.5px;
151
  padding: 14px 32px !important;
152
- cursor: pointer;
153
- transition: opacity 0.2s, transform 0.15s;
154
  width: 100%;
155
  margin-top: 8px;
 
156
  }
157
- #run-btn:hover { opacity: 0.9; transform: translateY(-1px); }
158
  #run-btn:active { transform: translateY(0); }
159
-
160
- /* ── output cards ── */
161
- .output-card textarea,
162
- .output-card input {
163
  background: #0d1424 !important;
164
  border: 1px solid var(--border) !important;
165
  border-radius: var(--radius) !important;
@@ -169,13 +168,7 @@ label span {
169
  font-weight: 700 !important;
170
  text-align: center;
171
  }
172
-
173
- /* ── slider (score bar) ── */
174
- .score-slider input[type=range] {
175
- accent-color: var(--accent);
176
- }
177
-
178
- /* ── examples table ── */
179
  .gr-samples-table {
180
  background: var(--surface) !important;
181
  border: 1px solid var(--border) !important;
@@ -186,18 +179,10 @@ label span {
186
  font-size: 0.72rem !important;
187
  color: var(--muted) !important;
188
  text-transform: uppercase;
189
- letter-spacing: 0.5px;
190
  background: #0d1424 !important;
191
  }
192
- .gr-samples-table td {
193
- color: var(--text) !important;
194
- font-size: 0.88rem !important;
195
- }
196
- .gr-samples-table tr:hover td {
197
- background: rgba(56,189,248,0.04) !important;
198
- }
199
-
200
- /* ── info footer ── */
201
  #footer-info {
202
  background: var(--surface);
203
  border: 1px solid var(--border);
@@ -215,10 +200,9 @@ label span {
215
  #footer-info span { color: var(--accent) !important; }
216
  """
217
 
218
- # ── Gradio UI ────────────────────────────────────────────────────────────────
219
  with gr.Blocks(css=custom_css, title="Address Entity Matcher") as demo:
220
 
221
- # Header
222
  gr.HTML("""
223
  <div id="header-box">
224
  <h1>📍 Address Entity Matcher</h1>
@@ -226,13 +210,11 @@ with gr.Blocks(css=custom_css, title="Address Entity Matcher") as demo:
226
  Enter two addresses to determine whether they refer to the same location.<br>
227
  Powered by a fine-tuned <strong>BGE-Reranker-v2-m3</strong> cross-encoder model.
228
  </p>
229
- <span class="badge">CrossEncoder</span>
230
  <span class="badge">BGE-Reranker-v2-m3</span>
231
  <span class="badge">Threshold: 0.75</span>
232
  </div>
233
  """)
234
 
235
- # Inputs
236
  with gr.Row(equal_height=True):
237
  with gr.Column(elem_classes="input-card"):
238
  input1 = gr.Textbox(
@@ -249,55 +231,39 @@ with gr.Blocks(css=custom_css, title="Address Entity Matcher") as demo:
249
 
250
  btn = gr.Button("🔎 Check Match", elem_id="run-btn", variant="primary")
251
 
252
- # Outputs
253
  with gr.Row(equal_height=True):
254
  with gr.Column(elem_classes="output-card"):
255
- similarity_output = gr.Textbox(
256
- label="Similarity Score",
257
- interactive=False,
258
- placeholder="—",
259
- )
260
  with gr.Column(elem_classes="output-card"):
261
- result_output = gr.Textbox(
262
- label="Match Result",
263
- interactive=False,
264
- placeholder="—",
265
- )
266
 
267
  score_bar = gr.Slider(
268
- minimum=0,
269
- maximum=100,
270
- value=0,
271
- step=0.01,
272
  label="Score Visualisation (threshold line: 75%)",
273
- interactive=False,
274
- elem_classes="score-slider",
275
  )
276
 
277
- # Examples
278
  gr.Examples(
279
  examples=[
280
- ["Flat 12-B Sector 5 Noida", "Flat 23-B Sector 5 Noida"],
281
- ["Phase 4 Whitefield Bangalore", "Whitefield Phase V Bangalore"],
282
- ["Thirteen I 7th Avenue Adyar Chennai", "13 Seventh Avenue Adyar Chennai"],
283
- ["Twenty Nine A Second Cross Koramangala Bengaluru", "47 Forty Seven B Third Street Indiranagar Bengaluru"],
284
- ["Plot 8 Banjara Hills Hyderabad", "Plot 8 Banjara Hills Hyderabad"],
285
- ["House No 4 Lane 2 DLF Phase 1 Gurugram", "H.No 4/2 DLF Phase One Gurgaon"],
286
  ],
287
  inputs=[input1, input2],
288
  label="📋 Try these examples",
289
  )
290
 
291
- # Footer info
292
  gr.HTML(f"""
293
  <div id="footer-info">
294
- <p>🤖 <strong>Model:</strong> <span>pujithapsx/address-crossencoder-bge-reranker-v2-m3-finetuned</span></p>
295
  <p>📏 <strong>Threshold:</strong> <span>{THRESHOLD}</span> — Score ≥ {THRESHOLD} → MATCH &nbsp;|&nbsp; Score &lt; {THRESHOLD} → NO MATCH</p>
296
  <p>🏷️ <strong>Confidence:</strong> High (score &gt; 0.85 or &lt; 0.40) &nbsp;|&nbsp; Medium (otherwise)</p>
297
  </div>
298
  """)
299
 
300
- # Wiring
301
  btn.click(
302
  fn=predict_similarity,
303
  inputs=[input1, input2],
@@ -305,4 +271,4 @@ with gr.Blocks(css=custom_css, title="Address Entity Matcher") as demo:
305
  )
306
 
307
  if __name__ == "__main__":
308
- demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertConfig
4
 
5
+ # ── Model setup ──────────────────────────────────────────────────────────────
6
  MODEL_NAME = "pujithapsx/address-crossencoder-bge-reranker-v2-m3-finetuned"
7
+ THRESHOLD = 0.75
8
 
9
+ print("Loading tokenizer...")
10
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
 
 
11
 
12
+ print("Loading model weights...")
13
+ # The model's config.json is missing "model_type", which breaks AutoConfig/CrossEncoder.
14
+ # BertConfig.from_pretrained() reads the repo's config.json without needing model_type.
15
+ config = BertConfig.from_pretrained(MODEL_NAME)
16
+ hf_model = AutoModelForSequenceClassification.from_pretrained(
17
+ MODEL_NAME,
18
+ config=config,
19
+ ignore_mismatched_sizes=True,
20
+ )
21
+ hf_model.eval()
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ hf_model.to(device)
24
+ print(f"Model loaded on {device}!")
25
+
26
+
27
+ # ── Inference ─────────────────────────────────────────────────────────────────
def _score_pair(text_a: str, text_b: str) -> float:
    """Score a single pair of texts with the module-level tokenizer and model.

    Args:
        text_a: First text of the pair.
        text_b: Second text of the pair.

    Returns:
        A float between 0 and 1: the sigmoid of the logit when the model
        emits one logit, otherwise the softmax probability at index 1.
    """
    encoded = tokenizer(
        text_a,
        text_b,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    # Move every tensor produced by the tokenizer onto the model's device.
    model_inputs = {}
    for key, tensor in encoded.items():
        model_inputs[key] = tensor.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        pair_logits = hf_model(**model_inputs).logits
        first_row = pair_logits[0]
        if pair_logits.shape[-1] != 1:
            # Multi-logit head: report the probability assigned to index 1.
            return torch.softmax(first_row, dim=-1)[1].item()
        # Single-logit head: squash the raw logit into the 0-1 range.
        return torch.sigmoid(first_row[0]).item()
45
 
46
 
def predict_similarity(input1, input2):
    """Compare two address strings and format the outcome for the UI.

    Args:
        input1: First address text. ``None`` is treated as empty.
        input2: Second address text. ``None`` is treated as empty.

    Returns:
        A 3-tuple ``(similarity, verdict, bar_value)``:
        the similarity as a percentage string with two decimals, the
        match / no-match verdict with its confidence label, and the
        percentage as a float. If either address is missing or blank,
        returns the placeholder ``"—"``, a warning message and ``0``
        without scoring the pair.
    """
    # Coerce None to "" so a missing value takes the "please provide both
    # addresses" path instead of raising AttributeError on .strip().
    address_a = (input1 or "").strip()
    address_b = (input2 or "").strip()
    if not address_a or not address_b:
        return "—", "⚠️ Please provide both addresses", 0

    score = _score_pair(address_a, address_b)
    similarity_pct = score * 100

    # Confidence is "High" only when the score is clearly on one side:
    # above 0.85 for a match, below 0.40 for a non-match.
    if score >= THRESHOLD:
        result = "✅ MATCH"
        confidence_label = "High" if score > 0.85 else "Medium"
    else:
        result = "❌ NO MATCH"
        confidence_label = "High" if score < 0.40 else "Medium"

    return (
        f"{similarity_pct:.2f}%",
        f"{result} • Confidence: {confidence_label}",
        float(similarity_pct),
    )
 
 
 
62
 
63
 
64
+ # ── Custom CSS ────────────────────────────────────────────────────────────────
65
  custom_css = """
66
  @import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500;600&display=swap');
67
 
68
  :root {
69
+ --bg: #0b0f1a;
70
+ --surface: #111827;
71
+ --border: #1f2d45;
72
+ --accent: #38bdf8;
73
+ --accent2: #818cf8;
74
+ --text: #e2e8f0;
75
+ --muted: #64748b;
76
+ --radius: 12px;
 
 
 
77
  }
 
78
  body, .gradio-container {
79
  background: var(--bg) !important;
80
  font-family: 'DM Sans', sans-serif !important;
81
  color: var(--text) !important;
82
  }
 
 
83
  #header-box {
84
  background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%);
85
  border: 1px solid var(--border);
 
103
  font-weight: 700 !important;
104
  color: var(--accent) !important;
105
  margin: 0 0 8px !important;
 
106
  }
107
  #header-box p {
108
  color: var(--muted) !important;
 
122
  margin-right: 8px;
123
  margin-top: 12px;
124
  }
125
+ .input-card textarea, .input-card input {
 
 
 
126
  background: var(--surface) !important;
127
  border: 1px solid var(--border) !important;
128
  border-radius: var(--radius) !important;
 
132
  padding: 14px 16px !important;
133
  transition: border-color 0.2s;
134
  }
135
+ .input-card textarea:focus, .input-card input:focus {
 
136
  border-color: var(--accent) !important;
137
  box-shadow: 0 0 0 3px rgba(56,189,248,0.1) !important;
138
  }
 
143
  letter-spacing: 0.5px;
144
  text-transform: uppercase;
145
  }
 
 
146
  #run-btn {
147
  background: linear-gradient(135deg, var(--accent) 0%, var(--accent2) 100%) !important;
148
  border: none !important;
 
151
  font-family: 'Space Mono', monospace !important;
152
  font-size: 0.9rem !important;
153
  font-weight: 700 !important;
 
154
  padding: 14px 32px !important;
 
 
155
  width: 100%;
156
  margin-top: 8px;
157
+ transition: opacity 0.2s, transform 0.15s;
158
  }
159
+ #run-btn:hover { opacity: 0.9; transform: translateY(-1px); }
160
  #run-btn:active { transform: translateY(0); }
161
+ .output-card textarea, .output-card input {
 
 
 
162
  background: #0d1424 !important;
163
  border: 1px solid var(--border) !important;
164
  border-radius: var(--radius) !important;
 
168
  font-weight: 700 !important;
169
  text-align: center;
170
  }
171
+ .score-slider input[type=range] { accent-color: var(--accent); }
 
 
 
 
 
 
172
  .gr-samples-table {
173
  background: var(--surface) !important;
174
  border: 1px solid var(--border) !important;
 
179
  font-size: 0.72rem !important;
180
  color: var(--muted) !important;
181
  text-transform: uppercase;
 
182
  background: #0d1424 !important;
183
  }
184
+ .gr-samples-table td { color: var(--text) !important; font-size: 0.88rem !important; }
185
+ .gr-samples-table tr:hover td { background: rgba(56,189,248,0.04) !important; }
 
 
 
 
 
 
 
186
  #footer-info {
187
  background: var(--surface);
188
  border: 1px solid var(--border);
 
200
  #footer-info span { color: var(--accent) !important; }
201
  """
202
 
203
+ # ── Gradio UI ────────────────────────────────────────────────────────────────
204
  with gr.Blocks(css=custom_css, title="Address Entity Matcher") as demo:
205
 
 
206
  gr.HTML("""
207
  <div id="header-box">
208
  <h1>📍 Address Entity Matcher</h1>
 
210
  Enter two addresses to determine whether they refer to the same location.<br>
211
  Powered by a fine-tuned <strong>BGE-Reranker-v2-m3</strong> cross-encoder model.
212
  </p>
 
213
  <span class="badge">BGE-Reranker-v2-m3</span>
214
  <span class="badge">Threshold: 0.75</span>
215
  </div>
216
  """)
217
 
 
218
  with gr.Row(equal_height=True):
219
  with gr.Column(elem_classes="input-card"):
220
  input1 = gr.Textbox(
 
231
 
232
  btn = gr.Button("🔎 Check Match", elem_id="run-btn", variant="primary")
233
 
 
234
  with gr.Row(equal_height=True):
235
  with gr.Column(elem_classes="output-card"):
236
+ similarity_output = gr.Textbox(label="Similarity Score", interactive=False, placeholder="—")
 
 
 
 
237
  with gr.Column(elem_classes="output-card"):
238
+ result_output = gr.Textbox(label="Match Result", interactive=False, placeholder="—")
 
 
 
 
239
 
240
  score_bar = gr.Slider(
241
+ minimum=0, maximum=100, value=0, step=0.01,
 
 
 
242
  label="Score Visualisation (threshold line: 75%)",
243
+ interactive=False, elem_classes="score-slider",
 
244
  )
245
 
 
246
  gr.Examples(
247
  examples=[
248
+ ["Flat 12-B Sector 5 Noida", "Flat 23-B Sector 5 Noida"],
249
+ ["Phase 4 Whitefield Bangalore", "Whitefield Phase V Bangalore"],
250
+ ["Thirteen I 7th Avenue Adyar Chennai", "13 Seventh Avenue Adyar Chennai"],
251
+ ["Twenty Nine A Second Cross Koramangala Bengaluru", "47 Forty Seven B Third Street Indiranagar Bengaluru"],
252
+ ["Plot 8 Banjara Hills Hyderabad", "Plot 8 Banjara Hills Hyderabad"],
253
+ ["House No 4 Lane 2 DLF Phase 1 Gurugram", "H.No 4/2 DLF Phase One Gurgaon"],
254
  ],
255
  inputs=[input1, input2],
256
  label="📋 Try these examples",
257
  )
258
 
 
259
  gr.HTML(f"""
260
  <div id="footer-info">
261
+ <p>🤖 <strong>Model:</strong> <span>{MODEL_NAME}</span></p>
262
  <p>📏 <strong>Threshold:</strong> <span>{THRESHOLD}</span> — Score ≥ {THRESHOLD} → MATCH &nbsp;|&nbsp; Score &lt; {THRESHOLD} → NO MATCH</p>
263
  <p>🏷️ <strong>Confidence:</strong> High (score &gt; 0.85 or &lt; 0.40) &nbsp;|&nbsp; Medium (otherwise)</p>
264
  </div>
265
  """)
266
 
 
267
  btn.click(
268
  fn=predict_similarity,
269
  inputs=[input1, input2],
 
271
  )
272
 
273
  if __name__ == "__main__":
274
+ demo.launch()