surazbhandari committed on
Commit
adc0ea3
·
verified ·
1 Parent(s): 3da5f3c

Restore complete repository state (revert docs-only push)

Browse files
.DS_Store ADDED
Binary file (10.2 kB). View file
 
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ models/mini/model.pt filter=lfs diff=lfs merge=lfs -text
2
+ models/mini/model.safetensors filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MODEL_CARD.md CHANGED
@@ -55,16 +55,11 @@ from src.inference import EmbeddingInference
55
  # Load -- just like sentence-transformers!
56
  model = EmbeddingInference.from_pretrained("surazbhandari/miniembed")
57
 
58
- # 1. Similarity
59
  score = model.similarity("Machine learning is great", "AI is wonderful")
60
  print(f"Similarity: {score:.4f}") # 0.4287
61
 
62
- # 2. Normal Embeddings
63
- embeddings = model.encode(["Machine learning is great", "AI is wonderful"])
64
- import numpy as np
65
- manual_score = np.dot(embeddings[0], embeddings[1]) # Dot product = Cosine Similarity
66
-
67
- # 3. Semantic Search
68
  docs = ["Python is great for AI", "I love pizza", "Neural networks learn patterns"]
69
  results = model.search("deep learning frameworks", docs, top_k=2)
70
  for r in results:
@@ -72,7 +67,7 @@ for r in results:
72
  # [0.498] Neural networks learn patterns
73
  # [0.413] Python is great for AI
74
 
75
- # 4. Clustering
76
  result = model.cluster_texts(["ML is cool", "Pizza is food", "AI rocks"], n_clusters=2)
77
  # Cluster 1: ['Pizza is food']
78
  # Cluster 2: ['ML is cool', 'AI rocks']
 
55
  # Load -- just like sentence-transformers!
56
  model = EmbeddingInference.from_pretrained("surazbhandari/miniembed")
57
 
58
+ # Similarity
59
  score = model.similarity("Machine learning is great", "AI is wonderful")
60
  print(f"Similarity: {score:.4f}") # 0.4287
61
 
62
+ # Semantic Search
 
 
 
 
 
63
  docs = ["Python is great for AI", "I love pizza", "Neural networks learn patterns"]
64
  results = model.search("deep learning frameworks", docs, top_k=2)
65
  for r in results:
 
67
  # [0.498] Neural networks learn patterns
68
  # [0.413] Python is great for AI
69
 
70
+ # Clustering
71
  result = model.cluster_texts(["ML is cool", "Pizza is food", "AI rocks"], n_clusters=2)
72
  # Cluster 1: ['Pizza is food']
73
  # Cluster 2: ['ML is cool', 'AI rocks']
README.md CHANGED
@@ -82,25 +82,15 @@ from src.inference import EmbeddingInference
82
 
83
  model = EmbeddingInference.from_pretrained("models/mini")
84
 
85
- # 1. Similarity
86
  score = model.similarity("Machine learning is great", "AI is wonderful")
87
  print(f"Similarity: {score:.4f}") # 0.4287
88
 
89
- # 2. Normal Embeddings
90
- embeddings = model.encode(["Machine learning is great", "AI is wonderful"])
91
- import numpy as np
92
- manual_score = np.dot(embeddings[0], embeddings[1]) # Dot product = Cosine Similarity
93
-
94
- # 3. Semantic Search
95
  docs = ["Python is great for AI", "I love pizza", "Neural networks learn patterns"]
96
  results = model.search("deep learning frameworks", docs, top_k=2)
97
  for r in results:
98
  print(f" [{r['score']:.3f}] {r['text']}")
99
-
100
- # 4. Clustering
101
- result = model.cluster_texts(["ML is cool", "Pizza is food", "AI rocks"], n_clusters=2)
102
- # Cluster 1: ['Pizza is food']
103
- # Cluster 2: ['ML is cool', 'AI rocks']
104
  ```
105
 
106
  For full Hugging Face integration, ensure you have `huggingface_hub` installed:
 
82
 
83
  model = EmbeddingInference.from_pretrained("models/mini")
84
 
85
+ # Similarity
86
  score = model.similarity("Machine learning is great", "AI is wonderful")
87
  print(f"Similarity: {score:.4f}") # 0.4287
88
 
89
+ # Semantic Search
 
 
 
 
 
90
  docs = ["Python is great for AI", "I love pizza", "Neural networks learn patterns"]
91
  results = model.search("deep learning frameworks", docs, top_k=2)
92
  for r in results:
93
  print(f" [{r['score']:.3f}] {r['text']}")
 
 
 
 
 
94
  ```
95
 
96
  For full Hugging Face integration, ensure you have `huggingface_hub` installed:
data/sample_data.jsonl ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {"query": "how to train an embedding model", "passage": "Training an embedding model involves using contrastive learning on query-passage pairs.", "source": "sample"}
2
+ {"query": "what is a transformer", "passage": "The Transformer is a deep learning model that uses self-attention mechanisms to process sequence data.", "source": "sample"}
3
+ {"query": "nike air max 90", "passage": "Men's Nike Air Max 90 Casual Shoes in Black and White.", "source": "sample"}
4
+ {"query": "samsung galaxy s21", "passage": "Samsung Galaxy S21 5G 128GB Unlocked Smartphone - Phantom Gray.", "source": "sample"}
5
+ {"query": "best winter coats", "passage": "The North Face Gotham Jacket III is one of the warmest winter parkas for heavy snow.", "source": "sample"}
6
+ {"query": "python programming for beginners", "passage": "Learn Python with this comprehensive guide covering variables, loops, and functions.", "source": "sample"}
7
+ {"query": "benefits of meditation", "passage": "Meditation can reduce stress, improve concentration, and increase happiness.", "source": "sample"}
8
+ {"query": "how to bake chocolate cake", "passage": "Whisk eggs and sugar, then fold in flour and melted chocolate for a perfect moist cake.", "source": "sample"}
9
+ {"query": "what is machine learning", "passage": "Machine learning is a field of AI that allows systems to learn patterns from data without explicit programming.", "source": "sample"}
10
+ {"query": "running shoes for flat feet", "passage": "Brooks Adrenaline GTS 22 provides excellent stability and support for runners with low arches.", "source": "sample"}
demo.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MiniEmbed - Interactive Demo
3
+ ================================
4
+ Explore the embedding model's capabilities through a Streamlit dashboard.
5
+
6
+ Features:
7
+ - Pairwise text similarity (cosine distance)
8
+ - Semantic document search with ranked results
9
+ - Unsupervised text clustering via K-Means
10
+ - Raw embedding vector inspection and visualization
11
+ - Bulk CSV-to-CSV record matching
12
+
13
+ Run: streamlit run demo.py
14
+ """
15
+
16
+ import streamlit as st
17
+ import numpy as np
18
+ import pandas as pd
19
+ import os
20
+ import sys
21
+ import io
22
+
23
+ # Add src to path
24
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
25
+
26
+ from src.inference import EmbeddingInference, EmbeddingModelManager
27
+
28
+ # ============================================================================
29
+ # PAGE CONFIG
30
+ # ============================================================================
31
+
32
+ st.set_page_config(
33
+ page_title="MiniEmbed Demo",
34
+ page_icon="M",
35
+ layout="wide"
36
+ )
37
+
38
+ # Custom CSS
39
+ st.markdown("""
40
+ <style>
41
+ .main-header {
42
+ font-size: 2.5rem;
43
+ font-weight: 700;
44
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
45
+ -webkit-background-clip: text;
46
+ -webkit-text-fill-color: transparent;
47
+ text-align: center;
48
+ margin-bottom: 1rem;
49
+ }
50
+ .sub-header {
51
+ text-align: center;
52
+ color: #888;
53
+ margin-bottom: 2rem;
54
+ }
55
+ .result-box {
56
+ background: rgba(100, 100, 100, 0.1);
57
+ border-radius: 10px;
58
+ padding: 1rem;
59
+ margin: 0.5rem 0;
60
+ color: inherit;
61
+ }
62
+ .high-score { border-left: 4px solid #28a745; background: rgba(40, 167, 69, 0.1); }
63
+ .medium-score { border-left: 4px solid #ffc107; background: rgba(255, 193, 7, 0.1); }
64
+ .low-score { border-left: 4px solid #dc3545; background: rgba(220, 53, 69, 0.1); }
65
+ .score-text { font-weight: bold; }
66
+ </style>
67
+ """, unsafe_allow_html=True)
68
+
69
+ # ============================================================================
70
+ # LOAD MODEL
71
+ # ============================================================================
72
+
73
@st.cache_resource
def load_model(model_name):
    """Load and cache the embedding model selected in the sidebar.

    Args:
        model_name: Directory name under ``models/`` (e.g. ``"mini"``), or the
            special label ``"Legacy (model/)"`` which maps to the old flat
            ``model/`` directory.

    Returns:
        An ``EmbeddingInference`` instance, or ``None`` if loading fails.

    NOTE(review): the script below already guards ``if model is None:`` — but
    ``from_pretrained`` raises on a missing/corrupt model dir, so that guard
    was previously unreachable and users saw a raw traceback. Returning None
    on failure makes the existing error UI actually fire.
    """
    model_dir = f"models/{model_name}"
    if model_name == "Legacy (model/)":
        model_dir = "model"
    try:
        return EmbeddingInference.from_pretrained(model_dir)
    except Exception:
        # Deliberately broad: any load failure (missing file, bad weights)
        # should surface as the friendly "Model not found" message, not crash.
        return None
80
+
81
+
82
+ # Header
83
+ st.markdown('<h1 class="main-header">MiniEmbed Demo</h1>', unsafe_allow_html=True)
84
+ st.markdown('<p class="sub-header">Explore semantic similarity, search, clustering, and bulk matching</p>', unsafe_allow_html=True)
85
+
86
+ # -----------------------------------------------------------------------------
87
+ # Model Selection
88
+ # -----------------------------------------------------------------------------
89
+ available_models = EmbeddingModelManager.list_models()
90
+ if os.path.exists("model/model.pt"):
91
+ available_models.append("Legacy (model/)")
92
+
93
+ if not available_models:
94
+ st.error("No models found. Train a model first or place weights in models/mini/model.pt.")
95
+ st.info("Models should be located in the `models/` directory (e.g., `models/mini/`).")
96
+ st.stop()
97
+
98
+ selected_model_name = st.sidebar.selectbox(
99
+ "Select Model",
100
+ available_models,
101
+ index=0,
102
+ help="Select which trained model to load for inference."
103
+ )
104
+
105
+ model = load_model(selected_model_name)
106
+
107
+ if model is None:
108
+ st.error("Model not found. Please train the model first.")
109
+ st.stop()
110
+
111
+ # Model info
112
+ with st.expander("Model Info", expanded=False):
113
+ st.markdown("""
114
+ This panel shows the architecture of the currently loaded model.
115
+ - **Embedding Dim**: The size of each output vector (higher = more expressive).
116
+ - **Layers**: Number of Transformer encoder layers stacked in the model.
117
+ - **Vocab Size**: Total number of unique tokens the model can recognize.
118
+ """)
119
+ col1, col2, col3 = st.columns(3)
120
+ with col1:
121
+ st.metric("Embedding Dim", model.model.d_model)
122
+ with col2:
123
+ st.metric("Layers", len(model.model.layers))
124
+ with col3:
125
+ st.metric("Vocab Size", len(model.tokenizer.word_to_id))
126
+
127
+ # ============================================================================
128
+ # TABS
129
+ # ============================================================================
130
+
131
+ tab1, tab2, tab3, tab4, tab5 = st.tabs([
132
+ "Similarity",
133
+ "Semantic Search",
134
+ "Clustering",
135
+ "Encode Text",
136
+ "CSV Matcher"
137
+ ])
138
+
139
+ # ============================================================================
140
+ # TAB 1: SIMILARITY
141
+ # ============================================================================
142
+
143
+ with tab1:
144
+ st.markdown("### Pairwise Text Similarity")
145
+ st.markdown("""
146
+ Enter two texts to compute their **cosine similarity** (range: 0 to 1).
147
+ The model encodes each text into a 256-dimensional vector and measures
148
+ the angular distance between them. A score close to 1.0 means the texts
149
+ are semantically equivalent; a score near 0.0 means they are unrelated.
150
+ """)
151
+
152
+ col1, col2 = st.columns(2)
153
+
154
+ with col1:
155
+ text1 = st.text_area(
156
+ "Text 1",
157
+ "Machine learning is a branch of artificial intelligence",
158
+ height=100,
159
+ key="sim_text1"
160
+ )
161
+
162
+ with col2:
163
+ text2 = st.text_area(
164
+ "Text 2",
165
+ "AI systems can learn patterns from data",
166
+ height=100,
167
+ key="sim_text2"
168
+ )
169
+
170
+ if st.button("Compute Similarity", type="primary", key="sim_btn"):
171
+ if text1 and text2:
172
+ with st.spinner("Computing..."):
173
+ similarity = model.similarity(text1, text2)
174
+
175
+ if similarity > 0.7:
176
+ color = "#28a745"
177
+ label = "Very Similar"
178
+ elif similarity > 0.4:
179
+ color = "#ffc107"
180
+ label = "Somewhat Similar"
181
+ else:
182
+ color = "#dc3545"
183
+ label = "Not Similar"
184
+
185
+ st.markdown(f"""
186
+ <div style="text-align: center; padding: 2rem;">
187
+ <div style="font-size: 4rem; font-weight: bold; color: {color};">
188
+ {similarity:.3f}
189
+ </div>
190
+ <div style="font-size: 1.2rem; color: {color};">
191
+ {label}
192
+ </div>
193
+ </div>
194
+ """, unsafe_allow_html=True)
195
+
196
+ # Example pairs
197
+ st.markdown("---")
198
+ st.markdown("#### Example Pairs")
199
+ st.markdown("These pairs demonstrate how the model distinguishes related from unrelated content:")
200
+
201
+ examples = [
202
+ ("Python is a programming language", "Java is used for software development"),
203
+ ("The cat sat on the mat", "A feline rested on the rug"),
204
+ ("Machine learning is fascinating", "I love eating pizza"),
205
+ ]
206
+
207
+ for t1, t2 in examples:
208
+ similarity = model.similarity(t1, t2)
209
+
210
+ if similarity > 0.5:
211
+ css_class = "high-score"
212
+ elif similarity > 0.3:
213
+ css_class = "medium-score"
214
+ else:
215
+ css_class = "low-score"
216
+
217
+ st.markdown(f"""
218
+ <div class="result-box {css_class}">
219
+ <strong>{similarity:.3f}</strong> | "{t1}" vs "{t2}"
220
+ </div>
221
+ """, unsafe_allow_html=True)
222
+
223
+ # ============================================================================
224
+ # TAB 2: SEMANTIC SEARCH
225
+ # ============================================================================
226
+
227
+ with tab2:
228
+ st.markdown("### Semantic Document Search")
229
+ st.markdown("""
230
+ Enter a natural-language query. The model encodes your query and all
231
+ documents into the same vector space, then ranks documents by cosine
232
+ similarity. This finds **meaning-based** matches, not just keyword overlap.
233
+ """)
234
+
235
+ default_docs = """Python is a high-level programming language
236
+ Machine learning algorithms learn patterns from data
237
+ The weather today is sunny and warm
238
+ Neural networks are inspired by the human brain
239
+ JavaScript is used for web development
240
+ Deep learning has transformed computer vision
241
+ Cats are popular pets around the world
242
+ TensorFlow and PyTorch are ML frameworks
243
+ The stock market had a volatile day
244
+ Natural language processing understands text"""
245
+
246
+ query = st.text_input(
247
+ "Search Query",
248
+ "How do AI systems learn from examples?",
249
+ key="search_query"
250
+ )
251
+
252
+ documents_text = st.text_area(
253
+ "Documents (one per line)",
254
+ default_docs,
255
+ height=200,
256
+ key="search_docs"
257
+ )
258
+
259
+ top_k = st.slider("Number of results", 1, 10, 5, key="search_topk")
260
+
261
+ if st.button("Search", type="primary", key="search_btn"):
262
+ documents = [d.strip() for d in documents_text.split('\n') if d.strip()]
263
+
264
+ if query and documents:
265
+ with st.spinner("Searching..."):
266
+ results = model.search(query, documents, top_k=top_k)
267
+
268
+ st.markdown("### Results")
269
+ st.markdown("Documents ranked by semantic relevance to your query:")
270
+
271
+ for r in results:
272
+ score = r['score']
273
+ if score > 0.6:
274
+ indicator = "[HIGH]"
275
+ css_class = "high-score"
276
+ elif score > 0.4:
277
+ indicator = "[MED]"
278
+ css_class = "medium-score"
279
+ else:
280
+ indicator = "[LOW]"
281
+ css_class = "low-score"
282
+
283
+ st.markdown(f"""
284
+ <div class="result-box {css_class}">
285
+ <strong>{indicator} #{r['rank']}</strong> (score: {score:.4f})<br>
286
+ {r['text']}
287
+ </div>
288
+ """, unsafe_allow_html=True)
289
+
290
+ # ============================================================================
291
+ # TAB 3: CLUSTERING
292
+ # ============================================================================
293
+
294
+ with tab3:
295
+ st.markdown("### Unsupervised Text Clustering")
296
+ st.markdown("""
297
+ The model encodes each text into a dense vector. K-Means clustering
298
+ then groups these vectors by proximity in the embedding space.
299
+ Texts that are semantically similar end up in the same cluster,
300
+ even if they share no common words.
301
+ """)
302
+
303
+ default_cluster_texts = """Python programming language
304
+ Machine learning algorithms
305
+ Deep learning neural networks
306
+ JavaScript web development
307
+ Cats and dogs as pets
308
+ Pizza and pasta Italian food
309
+ Sunny weather today
310
+ Rainy day forecast
311
+ Stock market trends
312
+ Financial news update"""
313
+
314
+ cluster_texts = st.text_area(
315
+ "Texts to cluster (one per line)",
316
+ default_cluster_texts,
317
+ height=200,
318
+ key="cluster_texts"
319
+ )
320
+
321
+ n_clusters = st.slider("Number of clusters", 2, 10, 3, key="n_clusters")
322
+
323
+ if st.button("Run Clustering", type="primary", key="cluster_btn"):
324
+ texts = [t.strip() for t in cluster_texts.split('\n') if t.strip()]
325
+
326
+ if len(texts) >= n_clusters:
327
+ with st.spinner("Clustering..."):
328
+ result = model.cluster_texts(texts, n_clusters=n_clusters)
329
+
330
+ st.markdown("### Cluster Assignments")
331
+ st.markdown("Each group contains texts that the model considers semantically related:")
332
+
333
+ colors = ["#667eea", "#28a745", "#ffc107", "#dc3545", "#17a2b8",
334
+ "#6f42c1", "#fd7e14", "#20c997", "#e83e8c", "#6c757d"]
335
+
336
+ for cluster_id in sorted(result['texts_by_cluster'].keys()):
337
+ cluster_texts_list = result['texts_by_cluster'][cluster_id]
338
+ color = colors[cluster_id % len(colors)]
339
+
340
+ st.markdown(f"""
341
+ <div style="background: {color}15; border-left: 4px solid {color};
342
+ padding: 1rem; border-radius: 5px; margin: 0.5rem 0;">
343
+ <strong style="color: {color};">Cluster {cluster_id + 1}</strong>
344
+ ({len(cluster_texts_list)} texts)
345
+ </div>
346
+ """, unsafe_allow_html=True)
347
+
348
+ for text in cluster_texts_list:
349
+ st.markdown(f" - {text}")
350
+ else:
351
+ st.warning(f"Need at least {n_clusters} texts to create {n_clusters} clusters.")
352
+
353
+ # ============================================================================
354
+ # TAB 4: ENCODE TEXT
355
+ # ============================================================================
356
+
357
+ with tab4:
358
+ st.markdown("### Raw Embedding Inspector")
359
+ st.markdown("""
360
+ Convert any text into its dense vector representation. The output is a
361
+ 256-dimensional float vector that is **L2-normalized** (unit length = 1.0).
362
+ This is the same representation used internally for similarity and search.
363
+ """)
364
+
365
+ encode_text = st.text_area(
366
+ "Text to encode",
367
+ "Machine learning is a fascinating field of study.",
368
+ height=100,
369
+ key="encode_text"
370
+ )
371
+
372
+ if st.button("Encode", type="primary", key="encode_btn"):
373
+ if encode_text:
374
+ with st.spinner("Encoding..."):
375
+ embedding = model.encode(encode_text)
376
+
377
+ st.markdown("### Embedding Vector")
378
+
379
+ col1, col2, col3 = st.columns(3)
380
+ with col1:
381
+ st.metric("Dimensions", embedding.shape[1])
382
+ with col2:
383
+ st.metric("L2 Norm", f"{np.linalg.norm(embedding[0]):.4f}")
384
+ with col3:
385
+ st.metric("Mean Value", f"{embedding[0].mean():.4f}")
386
+
387
+ st.markdown("#### First 20 values:")
388
+ st.code(str(embedding[0][:20].round(4).tolist()))
389
+
390
+ st.markdown("#### Value Distribution")
391
+ st.markdown("A well-trained model produces a roughly Gaussian distribution centered near zero:")
392
+ import plotly.express as px
393
+ fig = px.histogram(
394
+ x=embedding[0],
395
+ nbins=50,
396
+ title="Embedding Value Distribution",
397
+ labels={'x': 'Value', 'y': 'Count'}
398
+ )
399
+ fig.update_layout(showlegend=False)
400
+ st.plotly_chart(fig, width="stretch")
401
+
402
+ # ============================================================================
403
+ # TAB 5: CSV MATCHER
404
+ # ============================================================================
405
+
406
+ with tab5:
407
+ st.markdown("### Bulk CSV Record Matcher")
408
+ st.markdown("""
409
+ Upload two CSV files and match rows across them using semantic similarity.
410
+ This is useful for:
411
+ - **Product deduplication** across e-commerce platforms
412
+ - **Record linkage** between databases with inconsistent naming
413
+ - **Cross-platform mapping** (e.g., matching supplier catalogs to your inventory)
414
+
415
+ The model encodes the selected text column from each CSV, then ranks
416
+ every row in CSV 2 against each row in CSV 1 by cosine similarity.
417
+ """)
418
+
419
+ col1, col2 = st.columns(2)
420
+
421
+ with col1:
422
+ st.markdown("#### Upload CSV 1 (Queries)")
423
+ file1 = st.file_uploader("Upload primary CSV", type=['csv'], key="csv_file_1")
424
+
425
+ with col2:
426
+ st.markdown("#### Upload CSV 2 (Knowledge Base)")
427
+ file2 = st.file_uploader("Upload secondary CSV", type=['csv'], key="csv_file_2")
428
+
429
+ if file1 and file2:
430
+ df1 = pd.read_csv(file1)
431
+ df2 = pd.read_csv(file2)
432
+
433
+ st.markdown("---")
434
+ col_m1, col_m2 = st.columns(2)
435
+
436
+ with col_m1:
437
+ col1_name = st.selectbox("Select column to match from CSV 1", df1.columns, key="col1_sel")
438
+
439
+ with col_m2:
440
+ col2_name = st.selectbox("Select column to search in CSV 2", df2.columns, key="col2_sel")
441
+
442
+ col_p1, col_p2 = st.columns(2)
443
+ with col_p1:
444
+ top_n_candidates = st.slider("Step 1: Top candidates to fetch", 1, 50, 10, help="Initial semantic search depth")
445
+ with col_p2:
446
+ top_m_final = st.slider("Step 2: Top matches to keep", 1, 10, 3, help="Final number of matches per row")
447
+
448
+ if st.button("Start Bulk Matching", type="primary"):
449
+ progress_bar = st.progress(0)
450
+ status_text = st.empty()
451
+
452
+ queries = df1[col1_name].fillna("").astype(str).tolist()
453
+ corpus = df2[col2_name].fillna("").astype(str).tolist()
454
+
455
+ status_text.text("Encoding search corpus (CSV 2)...")
456
+ corpus_embs = model.encode(corpus, batch_size=128)
457
+ progress_bar.progress(20)
458
+
459
+ status_text.text("Encoding queries (CSV 1)...")
460
+ query_embs = model.encode(queries, batch_size=128)
461
+ progress_bar.progress(50)
462
+
463
+ status_text.text("Computing similarities and mapping...")
464
+ similarities = np.dot(query_embs, corpus_embs.T)
465
+ progress_bar.progress(80)
466
+
467
+ all_results = []
468
+ for i in range(len(queries)):
469
+ row_scores = similarities[i]
470
+ top_indices = np.argsort(row_scores)[::-1][:top_m_final]
471
+
472
+ res_row = df1.iloc[i].to_dict()
473
+ for rank, idx in enumerate(top_indices, 1):
474
+ res_row[f'Match_{rank}_{col2_name}'] = corpus[idx]
475
+ res_row[f'Match_{rank}_Score'] = round(float(row_scores[idx]), 4)
476
+ all_results.append(res_row)
477
+
478
+ res_df = pd.DataFrame(all_results)
479
+
480
+ progress_bar.progress(100)
481
+ status_text.text("Matching complete.")
482
+
483
+ st.markdown("### Results Preview")
484
+ st.dataframe(res_df.head(50), width="stretch")
485
+
486
+ output = io.StringIO()
487
+ res_df.to_csv(output, index=False)
488
+ csv_string = output.getvalue()
489
+
490
+ st.download_button(
491
+ label="Download Full Results CSV",
492
+ data=csv_string,
493
+ file_name="semantic_matching_results.csv",
494
+ mime="text/csv",
495
+ )
496
+ else:
497
+ st.info("Upload both CSV files to begin matching.")
498
+
499
+
500
+ # ============================================================================
501
+ # FOOTER
502
+ # ============================================================================
503
+
504
+ st.markdown("---")
505
+ st.markdown("""
506
+ <div style="text-align: center; color: #666; padding: 1rem;">
507
+ <strong>MiniEmbed</strong> | Lightweight Text Embeddings |
508
+ <a href="https://github.com/bhandarisuraz/miniembed">GitHub</a>
509
+ </div>
510
+ """, unsafe_allow_html=True)
examples/basic_usage.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Basic Usage Example
3
+ ===================
4
+ Demonstrates encoding texts and computing similarity using MiniEmbed.
5
+
6
+ This script shows the three core operations:
7
+ 1. Encoding raw text into dense vectors
8
+ 2. Computing pairwise similarity between two texts
9
+ 3. Building a full similarity matrix across sets of texts
10
+ """
11
+
12
+ import sys
13
+ sys.path.insert(0, '..')
14
+
15
+ from src.inference import EmbeddingInference
16
+
17
+
18
def main():
    """Run the three basic MiniEmbed demos: encode, similarity, matrix.

    Expects to be launched from the ``examples/`` directory (the model is
    loaded via the relative path ``../models/mini``).
    """
    print("=" * 60)
    print("MiniEmbed - Basic Usage Example")
    print("=" * 60)

    # Load the model (relative path — assumes cwd is examples/; TODO confirm)
    print("\nLoading model...")
    model = EmbeddingInference.from_pretrained("../models/mini")
    print("Model loaded.\n")

    # -------------------------------------------------------------------------
    # Example 1: Encode texts
    # -------------------------------------------------------------------------
    print("-" * 40)
    print("Example 1: Encoding Texts")
    print("-" * 40)

    texts = [
        "Machine learning is a branch of artificial intelligence",
        "Deep learning uses neural networks with many layers",
        "I love eating pizza on weekends"
    ]

    embeddings = model.encode(texts)
    print(f"Input: {len(texts)} texts")
    print(f"Output: {embeddings.shape}")  # (3, 256)

    # -------------------------------------------------------------------------
    # Example 2: Compute similarity
    # -------------------------------------------------------------------------
    print("\n" + "-" * 40)
    print("Example 2: Computing Similarity")
    print("-" * 40)

    pairs = [
        ("Machine learning is great", "AI is wonderful"),
        ("Machine learning is great", "I love pizza"),
        ("The cat sat on the mat", "A feline rested on the rug"),
    ]

    for text1, text2 in pairs:
        similarity = model.similarity(text1, text2)
        # " LOW" carries a leading space so both tags print at equal width.
        tag = "MATCH" if similarity > 0.5 else " LOW"
        print(f" [{tag}] {similarity:.4f} | '{text1}' vs '{text2}'")

    # -------------------------------------------------------------------------
    # Example 3: Pairwise similarity matrix
    # -------------------------------------------------------------------------
    print("\n" + "-" * 40)
    print("Example 3: Pairwise Similarity Matrix")
    print("-" * 40)

    texts_a = ["Machine learning", "Deep learning", "Natural language"]
    texts_b = ["AI models", "Neural networks", "Text processing"]

    # Matrix is indexed [i, j] below, so presumably shape (len(texts_a),
    # len(texts_b)) — verify against EmbeddingInference.pairwise_similarity.
    similarity_matrix = model.pairwise_similarity(texts_a, texts_b)

    print("\nSimilarity Matrix:")
    # Column headers truncated to 10 chars and right-aligned to match rows.
    print(" ", " ".join(f"{t[:10]:>10}" for t in texts_b))
    for i, text in enumerate(texts_a):
        row = " ".join(f"{similarity_matrix[i, j]:>10.4f}" for j in range(len(texts_b)))
        print(f"{text[:12]:>12}: {row}")

    print("\nDone.")


if __name__ == "__main__":
    main()
examples/clustering.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Text Clustering Example
3
+ =======================
4
+ Demonstrates how to cluster texts by semantic similarity using MiniEmbed.
5
+
6
+ The model encodes each text into a dense vector. K-Means clustering then
7
+ groups these vectors by proximity in the embedding space, even if the texts
8
+ share no common words.
9
+ """
10
+
11
+ import sys
12
+ sys.path.insert(0, '..')
13
+
14
+ from src.inference import EmbeddingInference
15
+
16
+
17
def main():
    """Cluster a mixed-topic text collection with MiniEmbed and print results.

    The 20 texts are authored as 4 topics x 5 texts (Technology, Food,
    Sports, Nature), and that fixed ordering is relied on later when the
    expected topic is recovered with ``i // 5``.
    """
    print("=" * 60)
    print("MiniEmbed - Text Clustering Example")
    print("=" * 60)

    # Load the model (relative path — assumes cwd is examples/; TODO confirm)
    print("\nLoading model...")
    model = EmbeddingInference.from_pretrained("../models/mini")
    print("Model loaded.\n")

    # -------------------------------------------------------------------------
    # Text collection (mixed topics)
    # -------------------------------------------------------------------------
    texts = [
        # Technology
        "Python is a versatile programming language",
        "Machine learning models learn from data",
        "JavaScript is used for web development",
        "Neural networks process information like the brain",
        "Software engineering involves designing systems",

        # Food
        "Pizza is my favorite Italian dish",
        "Sushi is a traditional Japanese cuisine",
        "Tacos are delicious Mexican street food",
        "Pasta with marinara sauce is comforting",
        "Ramen noodles are popular in Japan",

        # Sports
        "Football is the most popular sport worldwide",
        "Basketball requires teamwork and skill",
        "Tennis is an exciting individual sport",
        "Swimming is great for cardiovascular health",
        "Soccer World Cup attracts billions of viewers",

        # Nature
        "Mountains offer breathtaking scenic views",
        "Oceans cover most of the Earth's surface",
        "Forests are home to diverse wildlife",
        "Rivers provide fresh water to ecosystems",
        "Deserts have extreme temperature variations",
    ]

    print(f"Text Collection: {len(texts)} texts (4 topics)")

    # -------------------------------------------------------------------------
    # Cluster texts
    # -------------------------------------------------------------------------
    print("\nClustering texts into 4 groups...")

    # result is used below as a dict with 'texts_by_cluster' (cluster_id ->
    # list of texts) and 'labels' (per-text cluster id) — verify against
    # EmbeddingInference.cluster_texts.
    result = model.cluster_texts(texts, n_clusters=4)

    # -------------------------------------------------------------------------
    # Display results
    # -------------------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Clustering Results")
    print("=" * 60)

    for cluster_id in sorted(result['texts_by_cluster'].keys()):
        cluster_texts = result['texts_by_cluster'][cluster_id]

        # Cluster ids are 0-based internally; displayed 1-based.
        print(f"\n Cluster {cluster_id + 1} ({len(cluster_texts)} texts)")
        print("-" * 40)

        for text in cluster_texts:
            print(f" - {text}")

    # -------------------------------------------------------------------------
    # Evaluate clustering (simple check)
    # -------------------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Clustering Analysis")
    print("=" * 60)

    # Expected groupings (approximate) — slices rely on the authored ordering
    # of `texts` above (5 consecutive texts per topic).
    expected = {
        "Technology": texts[0:5],
        "Food": texts[5:10],
        "Sports": texts[10:15],
        "Nature": texts[15:20],
    }

    print("\nLabels assigned to each text:")
    for i, (text, label) in enumerate(zip(texts, result['labels'])):
        # i // 5 maps each text back to its intended topic name.
        topic = list(expected.keys())[i // 5]
        print(f" [{label}] ({topic}) {text[:50]}...")

    print("\nDone.")


if __name__ == "__main__":
    main()
examples/semantic_search.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Semantic Search Example
3
+ =======================
4
+ Demonstrates how to use MiniEmbed for document retrieval.
5
+
6
+ The model encodes a query and a corpus of documents into the same vector space,
7
+ then ranks documents by cosine similarity to the query. This finds results based
8
+ on meaning, not keyword overlap.
9
+ """
10
+
11
+ import sys
12
+ sys.path.insert(0, '..')
13
+
14
+ from src.inference import EmbeddingInference
15
+
16
+
17
+ def main():
18
+ print("=" * 60)
19
+ print("MiniEmbed - Semantic Search Example")
20
+ print("=" * 60)
21
+
22
+ # Load the model
23
+ print("\nLoading model...")
24
+ model = EmbeddingInference.from_pretrained("../models/mini")
25
+ print("Model loaded.\n")
26
+
27
+ # -------------------------------------------------------------------------
28
+ # Document collection
29
+ # -------------------------------------------------------------------------
30
+ documents = [
31
+ "Python is a high-level programming language known for its simplicity",
32
+ "Machine learning algorithms can learn patterns from data",
33
+ "The weather today is sunny with a high of 75 degrees",
34
+ "Neural networks are computational models inspired by the brain",
35
+ "JavaScript is widely used for web development",
36
+ "Deep learning has revolutionized computer vision and NLP",
37
+ "Cats are popular pets known for their independence",
38
+ "TensorFlow and PyTorch are popular deep learning frameworks",
39
+ "The stock market showed strong gains today",
40
+ "Natural language processing helps computers understand text"
41
+ ]
42
+
43
+ print(f"Document Collection: {len(documents)} documents")
44
+ for i, doc in enumerate(documents, 1):
45
+ print(f" {i}. {doc[:60]}...")
46
+
47
+ # -------------------------------------------------------------------------
48
+ # Search queries
49
+ # -------------------------------------------------------------------------
50
+ queries = [
51
+ "How do AI systems learn from examples?",
52
+ "What programming language is good for beginners?",
53
+ "Tell me about artificial neural networks",
54
+ ]
55
+
56
+ print("\n" + "=" * 60)
57
+ print("Search Results")
58
+ print("=" * 60)
59
+
60
+ for query in queries:
61
+ print(f"\n Query: \"{query}\"")
62
+ print("-" * 50)
63
+
64
+ results = model.search(query, documents, top_k=3)
65
+
66
+ for r in results:
67
+ score = r['score']
68
+ if score > 0.6:
69
+ tag = "[HIGH]"
70
+ elif score > 0.4:
71
+ tag = "[ MED]"
72
+ else:
73
+ tag = "[ LOW]"
74
+
75
+ print(f" {tag} #{r['rank']} (score: {score:.4f})")
76
+ print(f" {r['text']}")
77
+
78
+ # -------------------------------------------------------------------------
79
+ # Interactive search (optional)
80
+ # -------------------------------------------------------------------------
81
+ print("\n" + "=" * 60)
82
+ print("Interactive Search")
83
+ print("=" * 60)
84
+ print("Enter your own queries (type 'quit' to exit):\n")
85
+
86
+ while True:
87
+ try:
88
+ query = input(" Query: ").strip()
89
+ if query.lower() in ['quit', 'exit', 'q']:
90
+ break
91
+ if not query:
92
+ continue
93
+
94
+ results = model.search(query, documents, top_k=3)
95
+
96
+ print("\n Results:")
97
+ for r in results:
98
+ print(f" - [{r['score']:.3f}] {r['text'][:70]}...")
99
+ print()
100
+
101
+ except (KeyboardInterrupt, EOFError):
102
+ break
103
+
104
+ print("\nDone.")
105
+
106
+
107
+ if __name__ == "__main__":
108
+ main()
models/large/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # MiniEmbed - Large
2
+
3
+ Full-scale variant for maximum accuracy on complex semantic tasks.
4
+
5
+ Coming soon...
models/medium/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # MiniEmbed - Medium
2
+
3
+ Balanced variant offering higher accuracy with moderate compute requirements.
4
+
5
+ Coming soon...
models/mini/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 30000,
3
+ "d_model": 256,
4
+ "num_heads": 4,
5
+ "num_layers": 4,
6
+ "d_ff": 1024,
7
+ "max_seq_len": 128,
8
+ "pad_token_id": 0,
9
+ "size_name": "mini"
10
+ }
models/mini/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b039b35819583641fc877c8aa3ce28aadc098d0ef79757c7bb8683141d2cde21
3
+ size 43508675
models/mini/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f31f7aa0a6ea31a7cb1c0e68cf3a914b589866eef6b4379314de8ce64a139c8
3
+ size 43495744
models/mini/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/mini/training_info.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epochs": 10,
3
+ "total_steps": 149130,
4
+ "final_loss": 0.07483314797282219,
5
+ "training_samples": 3817707,
6
+ "batch_size": 256,
7
+ "learning_rate": 0.0002,
8
+ "date": "2026-02-13 22:53:59",
9
+ "training_time_minutes": 2940.0035917321843
10
+ }
models/product/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # MiniEmbed - Product
2
+
3
+ Fine-tuned variant of Mini, specialized for high-accuracy product matching.
4
+
5
+ Coming soon...
models/small/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # MiniEmbed - Small
2
+
3
+ A larger variant with increased capacity for general-purpose embeddings.
4
+
5
+ Coming soon...
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core
2
+ torch>=2.0.0
3
+ numpy>=1.21.0
4
+ tqdm>=4.64.0
5
+
6
+ # Demo UI
7
+ streamlit>=1.30.0
8
+ plotly>=5.0.0
9
+
10
+ # Optional (for clustering, CSV processing, & Benchmarking)
11
+ scikit-learn>=1.0.0
12
+ pandas>=2.0.0
13
+ psutil>=5.9.0
14
+ sentence-transformers>=2.2.0
src/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MiniEmbed - Lightweight Text Embedding Model
3
+ """
4
+
5
+ from .model import MiniTransformerEmbedding
6
+ from .tokenizer import SimpleTokenizer
7
+ from .inference import EmbeddingInference, EmbeddingModelManager
8
+
9
+ __version__ = "1.0.0"
10
+ __all__ = [
11
+ "MiniTransformerEmbedding",
12
+ "SimpleTokenizer",
13
+ "EmbeddingInference",
14
+ "EmbeddingModelManager"
15
+ ]
src/inference.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Saving & Inference Module
3
+ ===================================
4
+ Easy-to-use API for loading and running inference with the embedding model.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import json
11
+ import os
12
+ from pathlib import Path
13
+ from typing import List, Dict, Union, Tuple
14
+
15
+ from .model import MiniTransformerEmbedding
16
+ from .tokenizer import SimpleTokenizer
17
+
18
+
19
+ class EmbeddingModelManager:
20
+ """
21
+ Handles saving and loading the embedding model.
22
+
23
+ Save structure:
24
+ model_dir/
25
+ ├── config.json # Model architecture config
26
+ ├── model.pt # Model weights
27
+ ├── tokenizer.json # Vocabulary
28
+ └── training_info.json # Training metadata (optional)
29
+ """
30
+
31
+ @staticmethod
32
+ def save_model(
33
+ model: MiniTransformerEmbedding,
34
+ tokenizer: SimpleTokenizer,
35
+ save_dir: str,
36
+ training_info: dict = None
37
+ ):
38
+ """
39
+ Save model, tokenizer, and config for later use.
40
+
41
+ Args:
42
+ model: Trained MiniTransformerEmbedding
43
+ tokenizer: SimpleTokenizer with vocabulary
44
+ save_dir: Directory to save model
45
+ training_info: Optional training metadata
46
+ """
47
+ save_dir = Path(save_dir)
48
+ save_dir.mkdir(parents=True, exist_ok=True)
49
+
50
+ # 1. Save model config
51
+ config = {
52
+ 'vocab_size': len(tokenizer.word_to_id),
53
+ 'd_model': model.d_model,
54
+ 'num_heads': model.layers[0].attention.num_heads,
55
+ 'num_layers': len(model.layers),
56
+ 'd_ff': model.layers[0].feed_forward.linear1.out_features,
57
+ 'max_seq_len': model.positional_encoding.pe.size(1),
58
+ 'pad_token_id': model.pad_token_id,
59
+ 'size_name': save_dir.name # Use folder name as size name
60
+ }
61
+
62
+ with open(save_dir / 'config.json', 'w') as f:
63
+ json.dump(config, f, indent=2)
64
+
65
+ # 2. Save model weights
66
+ torch.save(model.state_dict(), save_dir / 'model.pt')
67
+
68
+ # 3. Save tokenizer vocabulary
69
+ tokenizer.save(str(save_dir / 'tokenizer.json'))
70
+
71
+ # 4. Save training info (optional)
72
+ if training_info:
73
+ with open(save_dir / 'training_info.json', 'w') as f:
74
+ json.dump(training_info, f, indent=2)
75
+
76
+ print(f"Model saved to: {save_dir}")
77
+
78
+ @staticmethod
79
+ def load_model(model_dir: str, device: str = None) -> Tuple[MiniTransformerEmbedding, SimpleTokenizer]:
80
+ """
81
+ Load model and tokenizer from a local directory or HuggingFace repo.
82
+
83
+ Args:
84
+ model_dir: Local directory path OR HuggingFace repo ID
85
+ (e.g., "surazbhandari/miniembed")
86
+ device: Device to load model on ('cpu', 'cuda', 'mps')
87
+
88
+ Returns:
89
+ (model, tokenizer) tuple
90
+ """
91
+ # Auto-detect HuggingFace repo ID (contains "/" but is not a local path)
92
+ if '/' in model_dir and not os.path.exists(model_dir):
93
+ model_dir = EmbeddingModelManager._download_from_hub(model_dir)
94
+
95
+ model_dir = Path(model_dir)
96
+
97
+ if device is None:
98
+ if torch.cuda.is_available():
99
+ device = 'cuda'
100
+ elif torch.backends.mps.is_available():
101
+ device = 'mps'
102
+ else:
103
+ device = 'cpu'
104
+
105
+ # 1. Load config
106
+ config_path = model_dir / 'config.json'
107
+
108
+ with open(config_path, 'r') as f:
109
+ config = json.load(f)
110
+
111
+ # 2. Load tokenizer
112
+ tokenizer_path = model_dir / 'tokenizer.json'
113
+
114
+ tokenizer = SimpleTokenizer(vocab_size=config['vocab_size'])
115
+ tokenizer.load(str(tokenizer_path))
116
+
117
+ # 3. Create and load model
118
+ model = MiniTransformerEmbedding(
119
+ vocab_size=config['vocab_size'],
120
+ d_model=config['d_model'],
121
+ num_heads=config['num_heads'],
122
+ num_layers=config['num_layers'],
123
+ d_ff=config['d_ff'],
124
+ max_seq_len=config['max_seq_len'],
125
+ pad_token_id=config['pad_token_id']
126
+ )
127
+
128
+ # Load weights (prefer safetensors)
129
+ st_path = model_dir / 'model.safetensors'
130
+ pt_path = model_dir / 'model.pt'
131
+
132
+ if st_path.exists():
133
+ from safetensors.torch import load_file
134
+ state_dict = load_file(str(st_path), device=device)
135
+ elif pt_path.exists():
136
+ state_dict = torch.load(pt_path, map_location=device, weights_only=True)
137
+ else:
138
+ raise FileNotFoundError(f"Neither model.safetensors nor model.pt found in {model_dir}")
139
+
140
+ model.load_state_dict(state_dict)
141
+ model = model.to(device)
142
+ model.eval()
143
+
144
+ return model, tokenizer
145
+
146
+ @staticmethod
147
+ def _download_from_hub(repo_id: str) -> str:
148
+ """
149
+ Download model files from a HuggingFace repository.
150
+
151
+ Args:
152
+ repo_id: HuggingFace repo ID (e.g., "surazbhandari/miniembed")
153
+
154
+ Returns:
155
+ Local directory path containing the downloaded files.
156
+ """
157
+ try:
158
+ from huggingface_hub import hf_hub_download, snapshot_download
159
+ except ImportError:
160
+ raise ImportError(
161
+ "huggingface_hub is required to download models from HuggingFace. "
162
+ "Install it with: pip install huggingface_hub"
163
+ )
164
+
165
+ # Download the full model snapshot
166
+ local_dir = snapshot_download(
167
+ repo_id=repo_id,
168
+ allow_patterns=["config.json", "model.safetensors", "model.pt", "tokenizer.json", "training_info.json"],
169
+ )
170
+
171
+ return local_dir
172
+
173
+ @staticmethod
174
+ def list_models(base_dir: str = "models") -> List[str]:
175
+ """
176
+ List available model names in the base directory.
177
+
178
+ Returns:
179
+ List of directory names containing valid models
180
+ """
181
+ path = Path(base_dir)
182
+ if not path.exists():
183
+ return []
184
+ return sorted([d.name for d in path.iterdir() if d.is_dir() and (d / "model.pt").exists()])
185
+
186
+ class EmbeddingInference:
187
+ """
188
+ High-level inference API for the embedding model.
189
+
190
+ Usage:
191
+ # From local directory
192
+ model = EmbeddingInference.from_pretrained("./models/mini")
193
+
194
+ # From HuggingFace
195
+ model = EmbeddingInference.from_pretrained("surazbhandari/miniembed")
196
+
197
+ # Encode texts
198
+ embeddings = model.encode(["Hello world", "Machine learning"])
199
+
200
+ # Compute similarity
201
+ score = model.similarity("query", "document")
202
+
203
+ # Semantic search
204
+ results = model.search("python programming", documents)
205
+ """
206
+
207
+ def __init__(
208
+ self,
209
+ model: MiniTransformerEmbedding,
210
+ tokenizer: SimpleTokenizer,
211
+ device: str = 'cpu',
212
+ max_length: int = 64
213
+ ):
214
+ self.model = model
215
+ self.tokenizer = tokenizer
216
+ self.device = device
217
+ self.max_length = max_length
218
+ self.model.eval()
219
+
220
+ @classmethod
221
+ def from_pretrained(cls, model_dir: str, device: str = None):
222
+ """
223
+ Load model from a local directory or HuggingFace repo ID.
224
+
225
+ Args:
226
+ model_dir: Local path (e.g., "models/mini") or
227
+ HuggingFace repo ID (e.g., "surazbhandari/miniembed")
228
+ device: Device to load on ('cpu', 'cuda', 'mps'). Auto-detected if None.
229
+ """
230
+ model, tokenizer = EmbeddingModelManager.load_model(model_dir, device)
231
+ if device is None:
232
+ device = next(model.parameters()).device.type
233
+ return cls(model, tokenizer, device)
234
+
235
+ def encode(
236
+ self,
237
+ texts: Union[str, List[str]],
238
+ batch_size: int = 32,
239
+ show_progress: bool = False
240
+ ) -> np.ndarray:
241
+ """
242
+ Encode texts to embeddings.
243
+
244
+ Args:
245
+ texts: Single text or list of texts
246
+ batch_size: Batch size for encoding
247
+ show_progress: Show progress bar
248
+
249
+ Returns:
250
+ numpy array of shape (n_texts, d_model)
251
+ """
252
+ if isinstance(texts, str):
253
+ texts = [texts]
254
+
255
+ all_embeddings = []
256
+
257
+ # Process in batches
258
+ for i in range(0, len(texts), batch_size):
259
+ batch_texts = texts[i:i + batch_size]
260
+
261
+ # Tokenize
262
+ encodings = [
263
+ self.tokenizer.encode(t, self.max_length)
264
+ for t in batch_texts
265
+ ]
266
+
267
+ input_ids = torch.stack([e['input_ids'] for e in encodings]).to(self.device)
268
+ attention_mask = torch.stack([e['attention_mask'] for e in encodings]).to(self.device)
269
+
270
+ # Encode
271
+ with torch.no_grad():
272
+ embeddings = self.model.encode(input_ids, attention_mask)
273
+
274
+ all_embeddings.append(embeddings.cpu().numpy())
275
+
276
+ return np.vstack(all_embeddings)
277
+
278
+ def similarity(self, text1: str, text2: str) -> float:
279
+ """Compute cosine similarity between two texts."""
280
+ emb1 = self.encode(text1)
281
+ emb2 = self.encode(text2)
282
+ return float(np.dot(emb1[0], emb2[0]))
283
+
284
+ def pairwise_similarity(self, texts1: List[str], texts2: List[str]) -> np.ndarray:
285
+ """
286
+ Compute pairwise similarity between two lists.
287
+
288
+ Returns:
289
+ Matrix of shape (len(texts1), len(texts2))
290
+ """
291
+ emb1 = self.encode(texts1)
292
+ emb2 = self.encode(texts2)
293
+ return np.dot(emb1, emb2.T)
294
+
295
+ def search(
296
+ self,
297
+ query: str,
298
+ documents: List[str],
299
+ top_k: int = 5
300
+ ) -> List[Dict]:
301
+ """
302
+ Semantic search: Find most similar documents to query.
303
+
304
+ Args:
305
+ query: Search query
306
+ documents: List of documents to search
307
+ top_k: Number of results to return
308
+
309
+ Returns:
310
+ List of dicts with 'text', 'score', 'rank'
311
+ """
312
+ query_emb = self.encode(query)
313
+ doc_embs = self.encode(documents)
314
+
315
+ # Compute similarities
316
+ scores = np.dot(doc_embs, query_emb.T).flatten()
317
+
318
+ # Get top-k indices
319
+ top_indices = np.argsort(scores)[::-1][:top_k]
320
+
321
+ results = []
322
+ for rank, idx in enumerate(top_indices, 1):
323
+ results.append({
324
+ 'rank': rank,
325
+ 'text': documents[idx],
326
+ 'score': float(scores[idx]),
327
+ 'index': int(idx)
328
+ })
329
+
330
+ return results
331
+
332
+ def cluster_texts(self, texts: List[str], n_clusters: int = 5) -> Dict:
333
+ """
334
+ Cluster texts by embedding similarity.
335
+
336
+ Returns:
337
+ Dict with 'labels' and 'texts_by_cluster'
338
+ """
339
+ from sklearn.cluster import KMeans
340
+
341
+ embeddings = self.encode(texts)
342
+
343
+ kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
344
+ labels = kmeans.fit_predict(embeddings)
345
+
346
+ return {
347
+ 'labels': labels.tolist(),
348
+ 'centroids': kmeans.cluster_centers_,
349
+ 'texts_by_cluster': {
350
+ i: [texts[j] for j in range(len(texts)) if labels[j] == i]
351
+ for i in range(n_clusters)
352
+ }
353
+ }
src/model.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mini-Transformer Embedding Model
3
+ ====================================
4
+ A lightweight transformer encoder for generating text embeddings.
5
+ Built from scratch using PyTorch.
6
+
7
+ Architecture:
8
+ - Token Embeddings + Sinusoidal Positional Encoding
9
+ - N Transformer Encoder Layers (Pre-LayerNorm)
10
+ - Multi-Head Self-Attention
11
+ - Position-wise Feed-Forward Networks
12
+ - Mean Pooling + L2 Normalization
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import math
19
+ from typing import Optional
20
+
21
+
22
+ class SinusoidalPositionalEncoding(nn.Module):
23
+ """
24
+ Sinusoidal positional encoding from "Attention Is All You Need".
25
+
26
+ Adds position information to token embeddings using sin/cos functions
27
+ at different frequencies, allowing the model to understand token order.
28
+ """
29
+
30
+ def __init__(self, d_model: int, max_seq_len: int = 512, dropout: float = 0.1):
31
+ super().__init__()
32
+ self.dropout = nn.Dropout(p=dropout)
33
+
34
+ # Create positional encoding matrix [max_seq_len, d_model]
35
+ pe = torch.zeros(max_seq_len, d_model)
36
+ position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
37
+
38
+ # Compute division term for frequencies
39
+ div_term = torch.exp(
40
+ torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
41
+ )
42
+
43
+ # Apply sin to even indices, cos to odd indices
44
+ pe[:, 0::2] = torch.sin(position * div_term)
45
+ pe[:, 1::2] = torch.cos(position * div_term)
46
+
47
+ # Add batch dimension and register as buffer (not a parameter)
48
+ pe = pe.unsqueeze(0) # [1, max_seq_len, d_model]
49
+ self.register_buffer('pe', pe)
50
+
51
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
52
+ """
53
+ Args:
54
+ x: Tensor of shape [batch_size, seq_len, d_model]
55
+ Returns:
56
+ Tensor with positional encoding added
57
+ """
58
+ x = x + self.pe[:, :x.size(1), :]
59
+ return self.dropout(x)
60
+
61
+
62
+ class MultiHeadSelfAttention(nn.Module):
63
+ """
64
+ Multi-Head Self-Attention mechanism.
65
+
66
+ Allows the model to jointly attend to information from different
67
+ representation subspaces at different positions.
68
+ """
69
+
70
+ def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
71
+ super().__init__()
72
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
73
+
74
+ self.d_model = d_model
75
+ self.num_heads = num_heads
76
+ self.d_k = d_model // num_heads # Dimension per head
77
+
78
+ # Linear projections for Q, K, V
79
+ self.W_q = nn.Linear(d_model, d_model)
80
+ self.W_k = nn.Linear(d_model, d_model)
81
+ self.W_v = nn.Linear(d_model, d_model)
82
+
83
+ # Output projection
84
+ self.W_o = nn.Linear(d_model, d_model)
85
+
86
+ self.dropout = nn.Dropout(dropout)
87
+ self.scale = math.sqrt(self.d_k)
88
+
89
+ def forward(
90
+ self,
91
+ x: torch.Tensor,
92
+ attention_mask: Optional[torch.Tensor] = None
93
+ ) -> torch.Tensor:
94
+ """
95
+ Args:
96
+ x: Input tensor [batch_size, seq_len, d_model]
97
+ attention_mask: Optional mask [batch_size, seq_len]
98
+ Returns:
99
+ Output tensor [batch_size, seq_len, d_model]
100
+ """
101
+ batch_size, seq_len, _ = x.size()
102
+
103
+ # Linear projections
104
+ Q = self.W_q(x) # [batch, seq, d_model]
105
+ K = self.W_k(x)
106
+ V = self.W_v(x)
107
+
108
+ # Reshape to [batch, num_heads, seq, d_k]
109
+ Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
110
+ K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
111
+ V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
112
+
113
+ # Scaled dot-product attention
114
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
115
+ # scores: [batch, num_heads, seq, seq]
116
+
117
+ # Apply attention mask (for padding)
118
+ if attention_mask is not None:
119
+ # Expand mask: [batch, 1, 1, seq]
120
+ mask = attention_mask.unsqueeze(1).unsqueeze(2)
121
+ scores = scores.masked_fill(mask == 0, float('-inf'))
122
+
123
+ # Softmax and dropout
124
+ attn_weights = F.softmax(scores, dim=-1)
125
+ attn_weights = self.dropout(attn_weights)
126
+
127
+ # Apply attention to values
128
+ context = torch.matmul(attn_weights, V)
129
+ # context: [batch, num_heads, seq, d_k]
130
+
131
+ # Reshape back: [batch, seq, d_model]
132
+ context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
133
+
134
+ # Output projection
135
+ output = self.W_o(context)
136
+
137
+ return output
138
+
139
+
140
+ class PositionwiseFeedForward(nn.Module):
141
+ """
142
+ Position-wise Feed-Forward Network.
143
+
144
+ Two linear transformations with a GELU activation in between.
145
+ Applied to each position separately and identically.
146
+ """
147
+
148
+ def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
149
+ super().__init__()
150
+ self.linear1 = nn.Linear(d_model, d_ff)
151
+ self.linear2 = nn.Linear(d_ff, d_model)
152
+ self.dropout = nn.Dropout(dropout)
153
+
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
+ """
156
+ Args:
157
+ x: Input tensor [batch_size, seq_len, d_model]
158
+ Returns:
159
+ Output tensor [batch_size, seq_len, d_model]
160
+ """
161
+ x = self.linear1(x)
162
+ x = F.gelu(x)
163
+ x = self.dropout(x)
164
+ x = self.linear2(x)
165
+ return x
166
+
167
+
168
+ class TransformerEncoderLayer(nn.Module):
169
+ """
170
+ Single Transformer Encoder Layer with Pre-LayerNorm.
171
+
172
+ Components:
173
+ 1. Multi-Head Self-Attention with residual connection
174
+ 2. Position-wise Feed-Forward with residual connection
175
+
176
+ Uses Pre-LayerNorm for better training stability.
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ d_model: int,
182
+ num_heads: int,
183
+ d_ff: int,
184
+ dropout: float = 0.1
185
+ ):
186
+ super().__init__()
187
+
188
+ # Layer normalization
189
+ self.norm1 = nn.LayerNorm(d_model)
190
+ self.norm2 = nn.LayerNorm(d_model)
191
+
192
+ # Sub-layers
193
+ self.attention = MultiHeadSelfAttention(d_model, num_heads, dropout)
194
+ self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
195
+
196
+ # Dropout for residual connections
197
+ self.dropout = nn.Dropout(dropout)
198
+
199
+ def forward(
200
+ self,
201
+ x: torch.Tensor,
202
+ attention_mask: Optional[torch.Tensor] = None
203
+ ) -> torch.Tensor:
204
+ """
205
+ Args:
206
+ x: Input tensor [batch_size, seq_len, d_model]
207
+ attention_mask: Optional mask [batch_size, seq_len]
208
+ Returns:
209
+ Output tensor [batch_size, seq_len, d_model]
210
+ """
211
+ # Pre-norm attention block
212
+ normed = self.norm1(x)
213
+ attn_output = self.attention(normed, attention_mask)
214
+ x = x + self.dropout(attn_output) # Residual connection
215
+
216
+ # Pre-norm feed-forward block
217
+ normed = self.norm2(x)
218
+ ff_output = self.feed_forward(normed)
219
+ x = x + self.dropout(ff_output) # Residual connection
220
+
221
+ return x
222
+
223
+
224
+ class MiniTransformerEmbedding(nn.Module):
225
+ """
226
+ Mini-Transformer Embedding Model.
227
+
228
+ Converts variable-length text sequences into fixed-size dense vectors
229
+ suitable for semantic similarity, search, and clustering tasks.
230
+
231
+ Architecture:
232
+ 1. Token Embedding Layer (vocab → d_model)
233
+ 2. Sinusoidal Positional Encoding
234
+ 3. N Transformer Encoder Layers
235
+ 4. Mean Pooling (sequence → single vector)
236
+ 5. L2 Normalization (for cosine similarity)
237
+ """
238
+
239
+ def __init__(
240
+ self,
241
+ vocab_size: int = 30000,
242
+ d_model: int = 256,
243
+ num_heads: int = 4,
244
+ num_layers: int = 4,
245
+ d_ff: int = 1024,
246
+ max_seq_len: int = 128,
247
+ dropout: float = 0.1,
248
+ pad_token_id: int = 0
249
+ ):
250
+ super().__init__()
251
+
252
+ self.d_model = d_model
253
+ self.pad_token_id = pad_token_id
254
+
255
+ # Token embedding
256
+ self.token_embedding = nn.Embedding(
257
+ vocab_size, d_model, padding_idx=pad_token_id
258
+ )
259
+
260
+ # Positional encoding
261
+ self.positional_encoding = SinusoidalPositionalEncoding(
262
+ d_model, max_seq_len, dropout
263
+ )
264
+
265
+ # Transformer encoder layers
266
+ self.layers = nn.ModuleList([
267
+ TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
268
+ for _ in range(num_layers)
269
+ ])
270
+
271
+ # Final layer norm
272
+ self.final_norm = nn.LayerNorm(d_model)
273
+
274
+ # Initialize weights
275
+ self._init_weights()
276
+
277
+ def _init_weights(self):
278
+ """Initialize weights using Xavier/Glorot initialization."""
279
+ for module in self.modules():
280
+ if isinstance(module, nn.Linear):
281
+ nn.init.xavier_uniform_(module.weight)
282
+ if module.bias is not None:
283
+ nn.init.zeros_(module.bias)
284
+ elif isinstance(module, nn.Embedding):
285
+ nn.init.normal_(module.weight, mean=0, std=0.02)
286
+ if module.padding_idx is not None:
287
+ nn.init.zeros_(module.weight[module.padding_idx])
288
+
289
+ def forward(
290
+ self,
291
+ input_ids: torch.Tensor,
292
+ attention_mask: Optional[torch.Tensor] = None
293
+ ) -> torch.Tensor:
294
+ """
295
+ Forward pass through the encoder.
296
+
297
+ Args:
298
+ input_ids: Token IDs [batch_size, seq_len]
299
+ attention_mask: Mask for padding [batch_size, seq_len]
300
+
301
+ Returns:
302
+ Token-level representations [batch_size, seq_len, d_model]
303
+ """
304
+ # Token embeddings with scaling
305
+ x = self.token_embedding(input_ids) * math.sqrt(self.d_model)
306
+
307
+ # Add positional encoding
308
+ x = self.positional_encoding(x)
309
+
310
+ # Pass through transformer layers
311
+ for layer in self.layers:
312
+ x = layer(x, attention_mask)
313
+
314
+ # Final layer norm
315
+ x = self.final_norm(x)
316
+
317
+ return x
318
+
319
+ def encode(
320
+ self,
321
+ input_ids: torch.Tensor,
322
+ attention_mask: Optional[torch.Tensor] = None
323
+ ) -> torch.Tensor:
324
+ """
325
+ Encode input tokens to a single embedding vector per sequence.
326
+
327
+ Uses mean pooling over non-padded tokens, followed by L2 normalization.
328
+
329
+ Args:
330
+ input_ids: Token IDs [batch_size, seq_len]
331
+ attention_mask: Mask for padding [batch_size, seq_len]
332
+
333
+ Returns:
334
+ Normalized embeddings [batch_size, d_model]
335
+ """
336
+ # Get token-level representations
337
+ token_embeddings = self.forward(input_ids, attention_mask)
338
+
339
+ # Mean pooling
340
+ if attention_mask is not None:
341
+ # Expand mask for broadcasting: [batch, seq, 1]
342
+ mask_expanded = attention_mask.unsqueeze(-1).float()
343
+
344
+ # Sum of embeddings (masked)
345
+ sum_embeddings = torch.sum(token_embeddings * mask_expanded, dim=1)
346
+
347
+ # Count of non-padded tokens
348
+ sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
349
+
350
+ # Mean
351
+ embeddings = sum_embeddings / sum_mask
352
+ else:
353
+ # Simple mean over all tokens
354
+ embeddings = torch.mean(token_embeddings, dim=1)
355
+
356
+ # L2 normalization for cosine similarity
357
+ embeddings = F.normalize(embeddings, p=2, dim=1)
358
+
359
+ return embeddings
src/tokenizer.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple Word-Level Tokenizer
3
+ ==============================
4
+ A basic tokenizer for demonstration purposes.
5
+ Converts text to token IDs with special tokens.
6
+ """
7
+
8
+ import re
9
+ import json
10
+ from typing import Dict, List, Optional
11
+ from collections import Counter
12
+ from tqdm import tqdm
13
+
14
+
15
class SimpleTokenizer:
    """
    A simple word-level tokenizer with special tokens.

    Special Tokens:
        - [PAD]: Padding token (id=0)
        - [UNK]: Unknown token (id=1)
        - [CLS]: Classification token (id=2)
        - [SEP]: Separator token (id=3)
    """

    def __init__(self, vocab_size: int = 30000):
        # Upper bound on vocabulary size (including the special tokens).
        self.vocab_size = vocab_size

        # Special tokens occupy the first four IDs.
        self.special_tokens = {
            '[PAD]': 0,
            '[UNK]': 1,
            '[CLS]': 2,
            '[SEP]': 3,
        }

        # Bidirectional word <-> ID mappings, seeded with the specials.
        self.word_to_id: Dict[str, int] = dict(self.special_tokens)
        self.id_to_word: Dict[int, str] = {v: k for k, v in self.special_tokens.items()}

        # Convenience aliases for the special-token IDs.
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.cls_token_id = 2
        self.sep_token_id = 3

    def _tokenize(self, text: str) -> List[str]:
        """
        Split text into tokens (simple word-level tokenization).

        Args:
            text: Input text string

        Returns:
            List of tokens
        """
        # Lowercase and basic cleaning.
        text = text.lower().strip()

        # Word runs become single tokens; each punctuation character
        # becomes its own token.
        return re.findall(r'\b\w+\b|[^\w\s]', text)

    def build_vocab(self, texts: List[str], min_freq: int = 2):
        """
        Build vocabulary from a list of texts.

        Args:
            texts: List of text strings
            min_freq: Minimum frequency for a word to be included
        """
        # Count word frequencies across the corpus.
        word_counts = Counter()
        for text in tqdm(texts, desc="Building vocabulary"):
            word_counts.update(self._tokenize(text))

        # Reserve room for the special tokens already in the vocab.
        max_words = self.vocab_size - len(self.special_tokens)

        # Most frequent words first (ties keep first-seen order, same as
        # a stable reverse sort); words below min_freq are skipped.
        for word, count in word_counts.most_common(max_words):
            if count >= min_freq and word not in self.word_to_id:
                idx = len(self.word_to_id)
                self.word_to_id[word] = idx
                self.id_to_word[idx] = word

        print(f"Vocabulary size: {len(self.word_to_id)}")

    def encode(self, text: str, max_length: int = 128) -> Dict:
        """
        Encode text to token IDs with attention mask.

        Output is always exactly max_length long: [CLS] tokens... [SEP]
        followed by [PAD] (attention mask 0) as needed.

        Args:
            text: Input text string
            max_length: Maximum sequence length

        Returns:
            Dictionary with 'input_ids' and 'attention_mask' tensors
        """
        # Lazy import: keeps the tokenizer importable without torch.
        import torch

        tokens = self._tokenize(text)

        # Reserve two slots for [CLS] and [SEP]. The max() guards
        # max_length < 2, where a negative slice bound would otherwise
        # keep almost all tokens and overflow max_length.
        token_ids = [self.cls_token_id]
        for token in tokens[:max(max_length - 2, 0)]:
            token_ids.append(self.word_to_id.get(token, self.unk_token_id))
        token_ids.append(self.sep_token_id)

        # Hard truncate so the output never exceeds max_length
        # (only triggers when max_length < 2).
        token_ids = token_ids[:max_length]

        # 1 for real tokens, 0 for padding.
        attention_mask = [1] * len(token_ids)

        # Right-pad up to max_length.
        padding_length = max_length - len(token_ids)
        token_ids.extend([self.pad_token_id] * padding_length)
        attention_mask.extend([0] * padding_length)

        return {
            'input_ids': torch.tensor(token_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long)
        }

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to text.

        Args:
            token_ids: List of token IDs

        Returns:
            Decoded text string ([PAD]/[CLS]/[SEP] are dropped,
            unknown IDs render as '[UNK]')
        """
        skip = {self.pad_token_id, self.cls_token_id, self.sep_token_id}
        tokens = [
            self.id_to_word.get(idx, '[UNK]')
            for idx in token_ids
            if idx not in skip
        ]
        return ' '.join(tokens)

    def save(self, path: str):
        """Save tokenizer vocabulary to JSON file."""
        data = {
            'vocab_size': self.vocab_size,
            'word_to_id': self.word_to_id,
        }
        # Explicit UTF-8 so the file is platform-independent.
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def load(self, path: str):
        """Load tokenizer vocabulary from JSON file."""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.vocab_size = data['vocab_size']
        self.word_to_id = data['word_to_id']
        # JSON object values round-trip as ints; rebuild the reverse map.
        self.id_to_word = {int(v): k for k, v in self.word_to_id.items()}

    def __len__(self) -> int:
        # Current vocabulary size (specials included).
        return len(self.word_to_id)