RISHABH KUMAR committed on
Commit
162b166
·
1 Parent(s): 19f9b01

Add Quora duplicate detector Gradio app

Browse files
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Gradio front-end for the Quora duplicate-question detector.

Intended for deployment on Hugging Face Spaces (Gradio SDK).
"""
import sys
from pathlib import Path

# Make both the project root and the streamlit-app folder importable so
# `helper` and the `src` package resolve regardless of the working directory.
ROOT = Path(__file__).resolve().parent
for _extra_path in (ROOT, ROOT / "streamlit-app"):
    sys.path.insert(0, str(_extra_path))

import nltk

# The feature-engineering pipeline needs the NLTK stopword corpus.
nltk.download("stopwords", quiet=True)

import helper

import gradio as gr
19
+
20
def predict_fn(q1: str, q2: str, model_name: str):
    """Validate the inputs, run the selected model, and format the result.

    Returns a (markdown_message, duplicate_probability) pair matching the
    Gradio outputs (result text + probability slider).
    """
    first = (q1 or "").strip()
    second = (q2 or "").strip()

    # Guard clauses: reject empty / too-short input before touching a model.
    if not (first and second):
        return "⚠️ Please enter both questions.", 0.0
    if min(len(first), len(second)) < 3:
        return "⚠️ Questions should be at least 3 characters.", 0.0

    try:
        # The dropdown label embeds the model family name.
        backend = "classical" if "Classical" in model_name else "transformer"
        pred, proba = helper.predict(first, second, backend)

        if pred:
            verdict = "**Duplicate** — These questions likely have the same meaning."
        else:
            verdict = "**Not Duplicate** — These questions appear to be different."
        return verdict, proba
    except Exception as e:
        # Surface any backend failure directly in the UI instead of crashing.
        return f"❌ Error: {str(e)}", 0.0
+
43
+
44
# ---------------------------------------------------------------------------
# Model discovery: fail fast when no artifacts are available.
# ---------------------------------------------------------------------------
available = helper.get_available_models()
if not available:
    raise RuntimeError("No models found. Add models to models/ or configure HF Hub download.")

inference_times = helper.get_inference_times()


def _with_latency(display_name: str) -> str:
    """Append the benchmarked mean latency (when known) to a model label."""
    key = "classical" if "Classical" in display_name else "transformer"
    ms = inference_times.get(key, {}).get("mean_ms", 0)
    return f"{display_name} (~{ms:.0f} ms)" if ms else display_name


model_choices_with_time = [
    _with_latency(helper.get_model_display_name(m)) for m in available
]

# ---------------------------------------------------------------------------
# UI layout.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Quora Duplicate Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 Quora Duplicate Question Pairs")
    gr.Markdown("Enter two questions to check if they are semantically duplicate.")

    with gr.Row():
        # Left column: the two questions, the model picker, and the trigger.
        with gr.Column(scale=2):
            q1 = gr.Textbox(
                label="Question 1",
                placeholder="e.g. What is the capital of India?",
                lines=2,
            )
            q2 = gr.Textbox(
                label="Question 2",
                placeholder="e.g. Which city is India's capital?",
                lines=2,
            )
            model_dropdown = gr.Dropdown(
                label="Model",
                choices=model_choices_with_time,
                value=model_choices_with_time[0],
            )
            check_btn = gr.Button("Check", variant="primary")
        # Right column: verdict text and probability readout.
        with gr.Column(scale=1):
            result_text = gr.Markdown(value="")
            proba_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0,
                label="Probability of Duplicate",
                interactive=False,
            )

    with gr.Accordion("Try example pairs", open=False):
        gr.Examples(
            examples=[
                ["How do I learn Python?", "What is the best way to learn Python programming?"],
                ["What is the capital of France?", "How do I cook pasta?"],
            ],
            inputs=[q1, q2],
            label="",
        )

    # Wire the button to the prediction callback.
    check_btn.click(
        fn=predict_fn,
        inputs=[q1, q2, model_dropdown],
        outputs=[result_text, proba_slider],
    )

    gr.Markdown("---")
    with gr.Accordion("About", open=False):
        gr.Markdown("""
        This app predicts whether two Quora questions are duplicates (same meaning).

        **Models:**
        - **Classical**: Random Forest or XGBoost on 25 handcrafted features + TF-IDF
        - **DistilBERT**: Fine-tuned transformer for sentence-pair classification

        *Built for fun & learning. Results may not always be accurate — use with caution.*
        """)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core data science
2
+ numpy>=1.24,<3
3
+ pandas>=2.0
4
+ scikit-learn>=1.3
5
+ scipy>=1.11
6
+
7
+ # NLP & text
8
+ nltk>=3.8
9
+ beautifulsoup4>=4.12
10
+ fuzzywuzzy>=0.18
11
+ python-Levenshtein>=0.21
12
+ distance>=0.1.3
13
+
14
+ # Models
15
+ xgboost>=2.0
16
+ lightgbm>=4.0
17
+
18
+ # Embeddings (Phase 2)
19
+ torch>=2.0
20
+ sentence-transformers>=2.2
21
+
22
+ # Transformer fine-tuning
23
+ transformers>=4.30
24
+ datasets>=2.14
25
+ accelerate>=0.20
26
+
27
+ # App (Gradio for HF Spaces)
28
+ gradio>=4.0
29
+ huggingface_hub>=0.20
30
+
31
+ # Visualization
32
+ matplotlib>=3.7
33
+ seaborn>=0.13
34
+ plotly>=5.18
35
+
36
+ # Progress & utils
37
+ tqdm>=4.65
38
+
39
+ # Jupyter (for notebooks)
40
+ jupyter>=1.0
41
+ ipykernel>=6.0
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (159 Bytes). View file
 
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (163 Bytes). View file
 
src/__pycache__/embeddings.cpython-310.pyc ADDED
Binary file (1.26 kB). View file
 
src/__pycache__/embeddings.cpython-312.pyc ADDED
Binary file (1.79 kB). View file
 
src/__pycache__/feature_engineering.cpython-310.pyc ADDED
Binary file (5.06 kB). View file
 
src/__pycache__/feature_engineering.cpython-312.pyc ADDED
Binary file (9.79 kB). View file
 
src/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.71 kB). View file
 
src/__pycache__/preprocessing.cpython-310.pyc ADDED
Binary file (4.84 kB). View file
 
src/__pycache__/preprocessing.cpython-312.pyc ADDED
Binary file (7.04 kB). View file
 
src/embeddings.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sentence Transformer embeddings for semantic similarity.
3
+ Uses MPS (Apple Silicon GPU) when available.
4
+ """
5
+ import numpy as np
6
+
7
+ _embedding_model = None
8
+
9
+
10
def get_embedding_model(device: str = None):
    """Return a cached SentenceTransformer, or None when deps are missing.

    The model is built once per process and reused afterwards. When no
    device is given, Apple-Silicon MPS is preferred over CPU.
    """
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model

    try:
        import torch
        from sentence_transformers import SentenceTransformer

        if device is None:
            device = "mps" if torch.backends.mps.is_available() else "cpu"
        _embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
        return _embedding_model
    except ImportError:
        # sentence-transformers / torch not installed: caller falls back.
        return None
+
27
+
28
def embedding_cosine_similarity(q1: str, q2: str, model=None) -> float:
    """Cosine similarity between the two questions' sentence embeddings.

    Falls back to 0.0 when no embedding model is available; a tiny
    epsilon keeps the division safe for zero-norm vectors.
    """
    if model is None:
        model = get_embedding_model()
        if model is None:
            return 0.0

    vec_a, vec_b = model.encode([q1, q2])
    denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b) + 1e-9
    return float(np.dot(vec_a, vec_b) / denominator)
src/feature_engineering.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Feature extraction for Quora question pairs.
"""
import distance
from fuzzywuzzy import fuzz
import numpy as np

from .preprocessing import preprocess


def _load_stopwords() -> set:
    """Load the NLTK English stopword set, downloading the corpus if absent."""
    try:
        from nltk.corpus import stopwords
        return set(stopwords.words('english'))
    except LookupError:
        import nltk
        nltk.download('stopwords', quiet=True)
        from nltk.corpus import stopwords
        return set(stopwords.words('english'))


# NLTK stopwords (no pickle dependency).
STOP_WORDS = _load_stopwords()

# Small constant added to denominators to avoid division by zero.
SAFE_DIV = 0.0001
22
+
23
+ def _common_words(q1: str, q2: str) -> int:
24
+ w1 = set(word.lower().strip() for word in q1.split())
25
+ w2 = set(word.lower().strip() for word in q2.split())
26
+ return len(w1 & w2)
27
+
28
+
29
+ def _total_words(q1: str, q2: str) -> int:
30
+ w1 = set(word.lower().strip() for word in q1.split())
31
+ w2 = set(word.lower().strip() for word in q2.split())
32
+ return len(w1) + len(w2)
33
+
34
+
35
def _fetch_token_features(q1: str, q2: str) -> list:
    """Eight token-overlap ratios/flags for the question pair.

    Layout: [word overlap over min/max, stopword overlap over min/max,
    token overlap over min/max, last-token match, first-token match].
    Returns all zeros when either question has no tokens.
    """
    features = [0.0] * 8

    tokens1, tokens2 = q1.split(), q2.split()
    if not tokens1 or not tokens2:
        return features

    # Split each token stream into content words vs. stopwords.
    words1 = {t for t in tokens1 if t not in STOP_WORDS}
    words2 = {t for t in tokens2 if t not in STOP_WORDS}
    stops1 = {t for t in tokens1 if t in STOP_WORDS}
    stops2 = {t for t in tokens2 if t in STOP_WORDS}

    shared_words = len(words1 & words2)
    shared_stops = len(stops1 & stops2)
    shared_tokens = len(set(tokens1) & set(tokens2))

    # SAFE_DIV keeps every denominator strictly positive.
    features[0] = shared_words / (min(len(words1), len(words2)) + SAFE_DIV)
    features[1] = shared_words / (max(len(words1), len(words2)) + SAFE_DIV)
    features[2] = shared_stops / (min(len(stops1), len(stops2)) + SAFE_DIV)
    features[3] = shared_stops / (max(len(stops1), len(stops2)) + SAFE_DIV)
    features[4] = shared_tokens / (min(len(tokens1), len(tokens2)) + SAFE_DIV)
    features[5] = shared_tokens / (max(len(tokens1), len(tokens2)) + SAFE_DIV)
    features[6] = int(tokens1[-1] == tokens2[-1])
    features[7] = int(tokens1[0] == tokens2[0])

    return features
63
+
64
+
65
def _fetch_length_features(q1: str, q2: str) -> list:
    """Three length features: absolute token-count difference, mean token
    count, and the longest-common-substring length normalised by the
    shorter question. Returns zeros when either question is empty."""
    features = [0.0] * 3

    tokens1, tokens2 = q1.split(), q2.split()
    if not tokens1 or not tokens2:
        return features

    features[0] = abs(len(tokens1) - len(tokens2))
    features[1] = (len(tokens1) + len(tokens2)) / 2

    # distance.lcsubstrings can yield nothing for disjoint strings;
    # guard against indexing an empty result (would raise IndexError).
    substrings = list(distance.lcsubstrings(q1, q2))
    if substrings:
        features[2] = len(substrings[0]) / (min(len(q1), len(q2)) + 1)

    return features
85
+
86
+
87
def _fetch_fuzzy_features(q1: str, q2: str) -> list:
    """Four fuzzywuzzy similarity scores (0-100) for the pair."""
    scorers = (fuzz.QRatio, fuzz.partial_ratio, fuzz.token_sort_ratio, fuzz.token_set_ratio)
    return [scorer(q1, q2) for scorer in scorers]
94
+
95
+
96
+ def _jaccard_similarity(q1: str, q2: str) -> float:
97
+ """|intersection| / |union| of word sets."""
98
+ w1 = set(word.lower().strip() for word in q1.split())
99
+ w2 = set(word.lower().strip() for word in q2.split())
100
+ if not w1 and not w2:
101
+ return 0.0
102
+ inter = len(w1 & w2)
103
+ union = len(w1 | w2)
104
+ return inter / union if union else 0.0
105
+
106
+
107
+ def _sentence_length_ratio(q1: str, q2: str) -> float:
108
+ """min(word_count) / max(word_count)."""
109
+ n1, n2 = len(q1.split()), len(q2.split())
110
+ if max(n1, n2) == 0:
111
+ return 0.0
112
+ return min(n1, n2) / max(n1, n2)
113
+
114
+
115
def query_point_creator(
    q1: str, q2: str, vectorizer, embedding_model=None
) -> np.ndarray:
    """Build the full feature row for one question pair.

    Concatenates the handcrafted features with the bag-of-words vectors of
    both questions. `vectorizer` must be a fitted CountVectorizer or
    TfidfVectorizer. When `embedding_model` is given, a semantic
    cosine-similarity feature is appended to the handcrafted part.
    """
    q1 = preprocess(q1)
    q2 = preprocess(q2)

    # Basic length / overlap statistics.
    handcrafted = [
        len(q1),
        len(q2),
        len(q1.split()),
        len(q2.split()),
        _common_words(q1, q2),
        _total_words(q1, q2),
        round(_common_words(q1, q2) / (_total_words(q1, q2) + SAFE_DIV), 2),
    ]
    handcrafted += _fetch_token_features(q1, q2)
    handcrafted += _fetch_length_features(q1, q2)
    handcrafted += _fetch_fuzzy_features(q1, q2)
    handcrafted += [_jaccard_similarity(q1, q2), _sentence_length_ratio(q1, q2)]

    # Optional semantic similarity from a sentence-transformer model.
    if embedding_model is not None:
        from .embeddings import embedding_cosine_similarity
        handcrafted.append(embedding_cosine_similarity(q1, q2, embedding_model))

    q1_bow = vectorizer.transform([q1]).toarray()
    q2_bow = vectorizer.transform([q2]).toarray()

    # Shape (1, n_handcrafted + 2 * vocab_size).
    return np.hstack((np.array(handcrafted).reshape(1, len(handcrafted)), q1_bow, q2_bow))
src/model.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model training and evaluation utilities.
3
+ """
4
+ import numpy as np
5
+ from sklearn.model_selection import StratifiedKFold
6
+ from sklearn.base import clone
7
+ from sklearn.metrics import (
8
+ accuracy_score,
9
+ log_loss,
10
+ precision_score,
11
+ recall_score,
12
+ f1_score,
13
+ roc_auc_score,
14
+ confusion_matrix,
15
+ )
16
+
17
+
18
def evaluate_model(model, X_test, y_test, prefix: str = ""):
    """Compute standard binary-classification metrics for `model`.

    Always includes accuracy/precision/recall/f1; log_loss and auc_roc are
    added when the model exposes predict_proba, with NaN substituted when
    a probabilistic metric cannot be computed (e.g. single-class y_test).
    """
    y_pred = model.predict(X_test)
    y_proba = None
    if hasattr(model, "predict_proba"):
        # Probability of the positive class.
        y_proba = model.predict_proba(X_test)[:, 1]

    metrics = {
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred, zero_division=0),
        "recall": recall_score(y_test, y_pred, zero_division=0),
        "f1": f1_score(y_test, y_pred, zero_division=0),
    }

    if y_proba is not None:
        for name, scorer in (("log_loss", log_loss), ("auc_roc", roc_auc_score)):
            try:
                metrics[name] = scorer(y_test, y_proba)
            except ValueError:
                metrics[name] = float("nan")

    return metrics
44
+
45
+
46
def print_metrics(metrics: dict, prefix: str = ""):
    """Pretty-print a metrics dict, formatting finite floats to 4 dp."""
    header = f"{prefix} " if prefix else ""
    print(f"\n--- {header}Metrics ---")
    for name, value in metrics.items():
        is_finite_float = isinstance(value, float) and not np.isnan(value)
        if is_finite_float:
            print(f" {name}: {value:.4f}")
        else:
            print(f" {name}: {value}")
    print()
56
+
57
+
58
def stratified_cv_evaluate(model, X, y, n_folds: int = 5, random_state: int = 42):
    """Stratified K-fold cross-validation.

    A fresh clone of `model` is fitted on each fold; per-fold F1/AUC are
    printed as the folds complete. Returns (mean_metrics, fold_metrics),
    where NaN values are excluded from the per-metric means.
    """
    from tqdm import tqdm

    splitter = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=random_state)

    fold_metrics = []
    progress = tqdm(
        enumerate(splitter.split(X, y)),
        total=n_folds,
        desc="CV folds",
        unit="fold",
    )
    for fold, (train_idx, val_idx) in progress:
        fitted = clone(model)
        fitted.fit(X[train_idx], y[train_idx])
        scores = evaluate_model(fitted, X[val_idx], y[val_idx])
        fold_metrics.append(scores)
        print(f" Fold {fold + 1}: F1={scores['f1']:.4f}, AUC={scores.get('auc_roc', 0):.4f}")

    # Average each metric over the folds, skipping NaN entries.
    mean_metrics = {}
    for key in fold_metrics[0]:
        usable = [
            m[key]
            for m in fold_metrics
            if not (isinstance(m[key], float) and np.isnan(m[key]))
        ]
        mean_metrics[key] = np.mean(usable) if usable else float("nan")

    return mean_metrics, fold_metrics
src/preprocessing.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Text preprocessing for Quora question pairs.
3
+ """
4
+ import re
5
+ from bs4 import BeautifulSoup
6
+
7
+ # Module-level constant (avoid recreating on every call)
8
+ CONTRACTIONS = {
9
+ "ain't": "am not",
10
+ "aren't": "are not",
11
+ "can't": "can not",
12
+ "can't've": "can not have",
13
+ "'cause": "because",
14
+ "could've": "could have",
15
+ "couldn't": "could not",
16
+ "couldn't've": "could not have",
17
+ "didn't": "did not",
18
+ "doesn't": "does not",
19
+ "don't": "do not",
20
+ "hadn't": "had not",
21
+ "hadn't've": "had not have",
22
+ "hasn't": "has not",
23
+ "haven't": "have not",
24
+ "he'd": "he would",
25
+ "he'd've": "he would have",
26
+ "he'll": "he will",
27
+ "he'll've": "he will have",
28
+ "he's": "he is",
29
+ "how'd": "how did",
30
+ "how'd'y": "how do you",
31
+ "how'll": "how will",
32
+ "how's": "how is",
33
+ "i'd": "i would",
34
+ "i'd've": "i would have",
35
+ "i'll": "i will",
36
+ "i'll've": "i will have",
37
+ "i'm": "i am",
38
+ "i've": "i have",
39
+ "isn't": "is not",
40
+ "it'd": "it would",
41
+ "it'd've": "it would have",
42
+ "it'll": "it will",
43
+ "it'll've": "it will have",
44
+ "it's": "it is",
45
+ "let's": "let us",
46
+ "ma'am": "madam",
47
+ "mayn't": "may not",
48
+ "might've": "might have",
49
+ "mightn't": "might not",
50
+ "mightn't've": "might not have",
51
+ "must've": "must have",
52
+ "mustn't": "must not",
53
+ "mustn't've": "must not have",
54
+ "needn't": "need not",
55
+ "needn't've": "need not have",
56
+ "o'clock": "of the clock",
57
+ "oughtn't": "ought not",
58
+ "oughtn't've": "ought not have",
59
+ "shan't": "shall not",
60
+ "sha'n't": "shall not",
61
+ "shan't've": "shall not have",
62
+ "she'd": "she would",
63
+ "she'd've": "she would have",
64
+ "she'll": "she will",
65
+ "she'll've": "she will have",
66
+ "she's": "she is",
67
+ "should've": "should have",
68
+ "shouldn't": "should not",
69
+ "shouldn't've": "should not have",
70
+ "so've": "so have",
71
+ "so's": "so as",
72
+ "that'd": "that would",
73
+ "that'd've": "that would have",
74
+ "that's": "that is",
75
+ "there'd": "there would",
76
+ "there'd've": "there would have",
77
+ "there's": "there is",
78
+ "they'd": "they would",
79
+ "they'd've": "they would have",
80
+ "they'll": "they will",
81
+ "they'll've": "they will have",
82
+ "they're": "they are",
83
+ "they've": "they have",
84
+ "to've": "to have",
85
+ "wasn't": "was not",
86
+ "we'd": "we would",
87
+ "we'd've": "we would have",
88
+ "we'll": "we will",
89
+ "we'll've": "we will have",
90
+ "we're": "we are",
91
+ "we've": "we have",
92
+ "weren't": "were not",
93
+ "what'll": "what will",
94
+ "what'll've": "what will have",
95
+ "what're": "what are",
96
+ "what's": "what is",
97
+ "what've": "what have",
98
+ "when's": "when is",
99
+ "when've": "when have",
100
+ "where'd": "where did",
101
+ "where's": "where is",
102
+ "where've": "where have",
103
+ "who'll": "who will",
104
+ "who'll've": "who will have",
105
+ "who's": "who is",
106
+ "who've": "who have",
107
+ "why's": "why is",
108
+ "why've": "why have",
109
+ "will've": "will have",
110
+ "won't": "will not",
111
+ "won't've": "will not have",
112
+ "would've": "would have",
113
+ "wouldn't": "would not",
114
+ "wouldn't've": "would not have",
115
+ "y'all": "you all",
116
+ "y'all'd": "you all would",
117
+ "y'all'd've": "you all would have",
118
+ "y'all're": "you all are",
119
+ "y'all've": "you all have",
120
+ "you'd": "you would",
121
+ "you'd've": "you would have",
122
+ "you'll": "you will",
123
+ "you'll've": "you will have",
124
+ "you're": "you are",
125
+ "you've": "you have",
126
+ }
127
+
128
+
129
def preprocess(q: str) -> str:
    """Normalise a raw question string for feature extraction.

    Steps, in order: lowercase & strip; spell out money/percent symbols;
    drop the '[math]' marker; abbreviate large round numbers (k/m/b);
    expand contractions; strip HTML tags; replace non-word characters
    with spaces.
    """
    q = str(q).lower().strip()

    # Spell out common symbols.
    for symbol, replacement in (
        ('%', ' percent'),
        ('$', ' dollar '),
        ('₹', ' rupee '),
        ('€', ' euro '),
        ('@', ' at '),
    ):
        q = q.replace(symbol, replacement)

    # The pattern '[math]' appears around 900 times in the whole dataset.
    q = q.replace('[math]', '')

    # Abbreviate large round numbers (1,000,000 -> 1m, etc.).
    q = q.replace(',000,000,000 ', 'b ')
    q = q.replace(',000,000 ', 'm ')
    q = q.replace(',000 ', 'k ')
    q = re.sub(r'([0-9]+)000000000', r'\1b', q)
    q = re.sub(r'([0-9]+)000000', r'\1m', q)
    q = re.sub(r'([0-9]+)000', r'\1k', q)

    # Expand contractions word-by-word, then catch leftover suffixes.
    q = ' '.join(CONTRACTIONS.get(word, word) for word in q.split())
    q = q.replace("'ve", " have")
    q = q.replace("n't", " not")
    q = q.replace("'re", " are")
    q = q.replace("'ll", " will")

    # Strip HTML tags (explicit parser avoids a bs4 warning).
    q = BeautifulSoup(q, "html.parser").get_text()

    # Replace every non-word character with a space and trim.
    q = re.sub(r'\W', ' ', q).strip()

    return q
streamlit-app/.DS_Store ADDED
Binary file (6.15 kB). View file
 
streamlit-app/__pycache__/helper.cpython-310.pyc ADDED
Binary file (5.22 kB). View file
 
streamlit-app/helper.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Model-loading helpers shared by the app front-ends.

Loads classical (RF/XGBoost) and transformer (DistilBERT) artifacts and
delegates feature extraction to the `src` package.
"""
import json
import pickle
import sys
from pathlib import Path
from typing import Optional, Tuple

# Make the project root importable so `src.*` resolves from this folder.
_project_root = Path(__file__).resolve().parent.parent
if str(_project_root) not in sys.path:
    sys.path.insert(0, str(_project_root))

from src.feature_engineering import query_point_creator as _query_point_creator
from src.embeddings import get_embedding_model

# Artifact locations.
_models_dir = _project_root / "models"
_app_dir = Path(__file__).resolve().parent
_transformer_dir = _models_dir / "transformer"
_inference_times_path = _models_dir / "inference_times.json"
26
+
27
+
28
def _ensure_models_from_hf():
    """Best-effort download of model artifacts from the HF Hub.

    Runs only when HF_MODEL_REPO is set and models/model.pkl is missing;
    any failure is printed and ignored so the app can still start with
    local artifacts.
    """
    import os

    repo_id = os.environ.get("HF_MODEL_REPO")
    if not repo_id or (_models_dir / "model.pkl").exists():
        return
    try:
        from huggingface_hub import snapshot_download
        _models_dir.mkdir(parents=True, exist_ok=True)
        snapshot_download(repo_id=repo_id, local_dir=str(_models_dir))
    except Exception as e:
        # Deliberate best-effort: never block startup on a download error.
        print(f"HF Hub download skipped or failed: {e}")


# Attempt the download once at import time (HF Spaces deployment).
_ensure_models_from_hf()

# Lazily populated singletons for the classical pipeline...
_classical_model = None
_classical_cv = None
_embedding_model = None

# ...and for the DistilBERT pipeline.
_transformer_model = None
_transformer_tokenizer = None
53
+
54
+
55
def _pick_artifact(filename: str):
    """Prefer models/<filename>; fall back to the copy beside this module."""
    candidate = _models_dir / filename
    return candidate if candidate.exists() else _app_dir / filename


def _get_cv_path():
    """Path to the fitted vectorizer pickle."""
    return _pick_artifact("cv.pkl")


def _get_model_path():
    """Path to the classical model pickle."""
    return _pick_artifact("model.pkl")
61
+
62
+
63
def get_available_models() -> list:
    """List the model identifiers whose artifacts are present on disk."""
    found = []
    if _get_model_path().exists() and _get_cv_path().exists():
        found.append("classical")
    if (_transformer_dir / "config.json").exists():
        found.append("transformer")
    return found


def get_inference_times() -> dict:
    """Load benchmark numbers from models/inference_times.json ({} on failure)."""
    if not _inference_times_path.exists():
        return {}
    try:
        with open(_inference_times_path) as f:
            return json.load(f)
    except Exception:
        # A malformed or unreadable benchmark file is non-fatal.
        return {}
82
+
83
+
84
def _load_classical():
    """Load (and cache) the classical model, vectorizer, and embedder.

    Artifacts are unpickled once per process and reused on subsequent
    calls. Returns (model, vectorizer, embedding_model).
    """
    global _classical_model, _classical_cv, _embedding_model
    if _classical_model is None:
        # Use context managers so the pickle file handles are closed
        # promptly — the previous `pickle.load(open(...))` leaked them.
        with open(_get_model_path(), "rb") as f:
            _classical_model = pickle.load(f)
        with open(_get_cv_path(), "rb") as f:
            _classical_cv = pickle.load(f)
        _embedding_model = get_embedding_model()
    return _classical_model, _classical_cv, _embedding_model
91
+
92
+
93
def _load_transformer():
    """Load (and cache) the fine-tuned DistilBERT model and tokenizer.

    Chooses the best available device (MPS, then CUDA, then CPU) and
    switches the model to eval mode. Returns (model, tokenizer).
    """
    global _transformer_model, _transformer_tokenizer
    if _transformer_model is None:
        import torch
        from transformers import AutoTokenizer, AutoModelForSequenceClassification

        _transformer_tokenizer = AutoTokenizer.from_pretrained(str(_transformer_dir))
        loaded = AutoModelForSequenceClassification.from_pretrained(str(_transformer_dir))
        if torch.backends.mps.is_available():
            device = "mps"
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        _transformer_model = loaded.to(device)
        _transformer_model.eval()
    return _transformer_model, _transformer_tokenizer
105
+
106
+
107
def query_point_creator(q1: str, q2: str):
    """Feature vector for the classical model (shared src pipeline + embeddings)."""
    _model, vectorizer, embedder = _load_classical()
    return _query_point_creator(q1, q2, vectorizer, embedding_model=embedder)
111
+
112
+
113
def predict_classical(q1: str, q2: str) -> Tuple[int, float]:
    """Classical-model prediction. Returns (label, duplicate probability)."""
    model, vectorizer, embedder = _load_classical()
    features = _query_point_creator(q1, q2, vectorizer, embedding_model=embedder)
    # Probability of the positive (duplicate) class; 0.5 decision threshold.
    proba = float(model.predict_proba(features)[0, 1])
    return int(proba >= 0.5), proba
120
+
121
+
122
def predict_transformer(q1: str, q2: str) -> Tuple[int, float]:
    """DistilBERT prediction. Returns (label, duplicate probability)."""
    import torch

    from src.preprocessing import preprocess

    model, tokenizer = _load_transformer()
    encoded = tokenizer(
        preprocess(q1),
        preprocess(q2),
        return_tensors="pt",
        truncation=True,
        max_length=128,
        padding="max_length",
    )
    # Move the input tensors onto the model's device before inference.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    proba = torch.softmax(logits, dim=-1)[0, 1].item()
    return (1 if proba >= 0.5 else 0), float(proba)
142
+
143
+
144
def predict(q1: str, q2: str, model_type: str) -> Tuple[int, float]:
    """Unified prediction entry point.

    model_type selects the backend: 'classical' or 'transformer'.
    Raises ValueError for anything else.
    """
    if model_type not in ("classical", "transformer"):
        raise ValueError(f"Unknown model_type: {model_type}")
    backend = predict_classical if model_type == "classical" else predict_transformer
    return backend(q1, q2)
151
+
152
+
153
def get_model_display_name(model_type: str) -> str:
    """Human-readable label for the model selector (identity for unknown types)."""
    display_names = {
        "classical": "Classical (RF/XGBoost + TF-IDF)",
        "transformer": "DistilBERT (Transformer)",
    }
    return display_names.get(model_type, model_type)