surazbhandari commited on
Commit
c792dcb
·
verified ·
1 Parent(s): 1197aec

Sync from GitHub Actions

Browse files
Files changed (9) hide show
  1. README.md +74 -3
  2. demo.py +216 -0
  3. model.safetensors +3 -0
  4. requirements.txt +15 -0
  5. src/__init__.py +15 -0
  6. src/inference.py +338 -0
  7. src/model.py +359 -0
  8. src/tokenizer.py +162 -0
  9. tokenizer.json +0 -0
README.md CHANGED
@@ -1,3 +1,74 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MiniEmbed Product Matcher
3
+ emoji: ""
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ pinned: false
7
+ license: mit
8
+ library_name: generic
9
+ tags:
10
+ - embeddings
11
+ - product-matching
12
+ ---
13
+
14
+ # MiniEmbed: Product Matching Model
15
+
16
+ This is a specialized version of **MiniEmbed**, fine-tuned exclusively for **high-accuracy product matching** (entity resolution).
17
+
18
+ Unlike general-purpose embedding models, this model is designed to determine if two product listings—often with different titles, specifications, or formatting—refer to the **exact same physical item**.
19
+
20
+ ## Use Case
21
+
22
+ **Cross-Catalog Product Matching**
23
+ * **Scenario**: You have a catalog (Site A) and want to find matching products in a competitor's catalog (Site B).
24
+ * **Challenge**: Titles differ ("iPhone 14 128GB" vs "Apple iPhone 14 Midnight 128GB"), specs are formatted differently, and noise/distractors exist.
25
+ * **Solution**: This model maps semantically identical products to the same vector space, ignoring irrelevant noise while paying attention to critical specs (GB, Model Number, Color).
26
+
27
+ ## Interactive Demo
28
+
29
+ This repository includes a **Streamlit** app to demonstrate the matching capability.
30
+
31
+ To run locally:
32
+
33
+ ```bash
34
+ pip install -r requirements.txt
35
+ streamlit run demo.py
36
+ ```
37
+
38
+ ## Model Architecture
39
+
40
+ * **Type**: Transformer Bi-Encoder (BERT-style)
41
+ * **Parameters**: ~10.8M (Mini)
42
+ * **Dimensions**: 256
43
+ * **Max Sequence Length**: 128 tokens
44
+ * **Format**: `SafeTensors` (Hugging Face ready)
45
+
46
+ ## Usage
47
+
48
+ You can use the provided `src` library to run inference in your own Python scripts:
49
+
50
+ ```python
51
+ from src.inference import EmbeddingInference
52
+
53
+ # Load model from current directory
54
+ model = EmbeddingInference.from_pretrained(".")
55
+
56
+ # Define two product titles
57
+ product_a = "Sony WH-1000XM5 Wireless Noise Canceling Headphones, Black"
58
+ product_b = "Sony WH1000XM5/B Headphones"
59
+
60
+ # Calculate similarity (0 to 1)
61
+ score = model.similarity(product_a, product_b)
62
+
63
+ if score > 0.82:
64
+ print(f"It's a match! (Score: {score:.4f})")
65
+ else:
66
+ print(f"Different products. (Score: {score:.4f})")
67
+ ```
68
+
69
+ ## Automated Sync
70
+
71
+ This repository is automatically synced to Hugging Face Spaces via GitHub Actions.
72
+
73
+
74
+ MIT
demo.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
import streamlit as st
import sys
import os
import random

import pandas as pd

# Make the repository root importable so the `src` package resolves when the
# app is launched via `streamlit run demo.py`.
sys.path.append(os.getcwd())

from src.inference import EmbeddingInference

st.set_page_config(page_title="Product Model Demo", layout="wide")

st.title("Product Model Identity Verification Demo")
st.markdown("""
This demo showcases the **Product Model's** ability to verify if two product listings represent the same physical item.
Key use case: **Matching Site A (e.g., Your Catalog) vs Site B (e.g., Competitor/Marketplace)**.
""")


# Load Model
@st.cache_resource
def load_model():
    """Load the embedding model from the repo root (cached across reruns).

    Returns None when no weight file is present so the UI can show a
    friendly error instead of crashing.
    """
    model_path = "."
    # The repo may ship weights as safetensors (preferred), .bin, or .pt —
    # all three are accepted by EmbeddingInference.from_pretrained.
    # (Checking only for pytorch_model.bin would make a safetensors-only
    # checkout appear "not trained".)
    weight_files = ("model.safetensors", "pytorch_model.bin", "model.pt")
    if not any(os.path.exists(os.path.join(model_path, w)) for w in weight_files):
        return None
    return EmbeddingInference.from_pretrained(model_path)


model = load_model()

if not model:
    st.error("Model weights not found in the repository root. Please train or download the model first.")
    st.stop()

st.success(f"Product Model Loaded! (Vocab: {len(model.tokenizer.word_to_id)}, Dim: {model.model.d_model})")

# Sidebar settings
st.sidebar.header("Settings")
threshold = st.sidebar.slider(
    "Match Threshold", 0.5, 1.0, 0.82, 0.01,
    help="Score above which products are considered a match."
)

# Data Input
st.subheader("1. Input Data")

# Seed the text areas (widget-backed session keys) once per session.
if 'txt_a' not in st.session_state:
    st.session_state.txt_a = """Apple iPhone 14 128GB Midnight
Samsung Galaxy S23 Ultra 256GB Black
Sony WH-1000XM5 Wireless Headphones
Nintendo Switch OLED White
Logitech MX Master 3S Performance Mouse
Nike Air Force 1 '07 White
Dyson V15 Detect Vacuum"""

if 'txt_b' not in st.session_state:
    st.session_state.txt_b = """iPhone 14 (128GB) - Midnight Black
Samsung S23 Ultra 5G (256GB Storage)
Sony Noise Cancelling Headphones WH1000XM5
Nintendo Switch Console - OLED Model
Logitech Mouse MX Master 3S
Nike Men's Air Force 1 Sneakers
Dyson V15 Detect Cordless Vacuum Cleaner
Apple iPhone 13 128GB
Samsung Galaxy S22 Ultra
Sony WH-1000XM4
Nintendo Switch Lite"""

# Button to load large dataset
if st.button("Load Large Benchmark Dataset (100+ items)"):
    # Each tuple is (Site A title, matching Site B title, hard-negative
    # distractor). Titles intentionally include real-world noise and typos
    # ("AirPos", "ASUS Rogers") to stress the matcher — kept as-is.
    base_products = [
        ("iPhone 14 Pro 128GB Space Black", "Apple iPhone 14 Pro (128 GB) - Space Black", "iPhone 14 Pro Max 128GB"),
        ("Samsung Galaxy S23 Ultra 512GB", "Samsung S23 Ultra 5G (512GB Storage)", "Samsung Galaxy S23 512GB"),
        ("Sony WH-1000XM5 Headphones", "Sony Noise Cancelling Wireless Headphones WH1000XM5", "Sony WH-1000XM4 Headphones"),
        ("MacBook Air M2 13-inch 256GB", "Apple MacBook Air Laptop: M2 chip, 13.6-inch, 256GB", "MacBook Pro M2 13-inch"),
        ("Dyson V15 Detect Vacuum", "Dyson V15 Detect Cordless Vacuum Cleaner", "Dyson V12 Detect Slim"),
        ("Logitech MX Master 3S", "Logitech Master Series MX 3S Mouse", "Logitech MX Master 3"),
        ("Kindle Paperwhite 16GB", "Amazon Kindle Paperwhite (16 GB) - 6.8 display", "Kindle Paperwhite 8GB"),
        ("Nintendo Switch OLED White", "Nintendo Switch – OLED Model w/ White Joy-Con", "Nintendo Switch Lite Blue"),
        ("PlayStation 5 Console", "Sony PS5 Console Disc Edition", "PlayStation 5 Digital Edition"),
        ("Xbox Series X", "Microsoft Xbox Series X 1TB Console", "Xbox Series S"),
        ("AirPos Pro 2nd Gen", "Apple AirPods Pro (2nd Generation) with MagSafe", "Apple AirPods 3rd Gen"),
        ("Fitbit Charge 6", "Fitbit Charge 6 Fitness Tracker with Google apps", "Fitbit Charge 5"),
        ("Garmin Forerunner 265", "Garmin Forerunner 265 Running Smartwatch", "Garmin Forerunner 965"),
        ("Yeti Rambler 20oz Tumbler", "YETI Rambler 20 oz Stainless Steel Vacuum Insulated", "Yeti Rambler 30oz"),
        ("Stanley Quencher H2.0 40oz", "Stanley The Quencher H2.0 FlowState Tumbler 40oz", "Stanley IceFlow Flip Straw"),
        ("Canon EOS R6 Mark II", "Canon Mirrorless Camera EOS R6 Mark II Body", "Canon EOS R5 Body"),
        ("Nikon Z6 II Body", "Nikon Z 6II FX-Format Mirrorless Camera", "Nikon Z7 II Body"),
        ("DJI Mini 3 Pro", "DJI Mini 3 Pro (DJI RC)", "DJI Mini 2 SE"),
        ("GoPro HERO11 Black", "GoPro HERO11 Black - Waterproof Action Camera", "GoPro HERO10 Black"),
        ("Razer DeathAdder V3 Pro", "Razer DeathAdder V3 Pro Wireless Gaming Mouse", "Razer Viper V2 Pro"),
        ("Keychron K2 Pro Keyboard", "Keychron K2 Pro QMK/VIA Wireless Mechanical Keyboard", "Keychron K2 Version 2"),
        ("Herman Miller Aeron Chair", "Herman Miller Aeron Ergonomic Office Chair", "Herman Miller Mirra 2"),
        ("Instant Pot Duo Plus 6qt", "Instant Pot Duo Plus 9-in-1 Electric Pressure Cooker", "Instant Pot Duo 7-in-1"),
        ("Ninja AF101 Air Fryer", "Ninja AF101 Air Fryer that Crisps, Roasts", "Ninja AF161 Max XL"),
        ("Vitamix 5200 Blender", "Vitamix 5200 Blender Professional-Grade", "Vitamix E310 Explorean"),
        ("Roomba j7+ Vacuum", "iRobot Roomba j7+ (7550) Self-Emptying Robot Vacuum", "Roomba i3+ EVO"),
        ("Sonos Arc Soundbar", "Sonos Arc - The Premium Smart Soundbar", "Sonos Beam Gen 2"),
        ("Bose QuietComfort 45", "Bose QuietComfort 45 Bluetooth Wireless Noise Cancelling", "Bose QuietComfort Earbuds II"),
        ("iPad Air 5th Gen 64GB", "Apple iPad Air (5th Generation): M1 chip, 64GB", "iPad 10th Gen 64GB"),
        ("Samsung T7 Shield 1TB", "Samsung T7 Shield 1TB Portable SSD", "Samsung T7 Touch 1TB"),
        ("SanDisk Extreme 2TB SSD", "SanDisk 2TB Extreme Portable SSD", "SanDisk Extreme Pro 2TB"),
        ("LG C3 OLED TV 65-inch", "LG 65-Inch Class C3 Series OLED evo 4K", "LG B3 OLED TV 65-inch"),
        ("Samsung QN90C 55-inch", "Samsung 55-Inch Class Neo QLED 4K QN90C", "Samsung QN85C 55-inch"),
        ("Google Pixel 7 Pro 128GB", "Google Pixel 7 Pro - 5G Android Phone 128GB", "Google Pixel 7 128GB"),
        ("OnePlus 11 5G 16GB RAM", "OnePlus 11 5G | 16GB RAM+256GB", "OnePlus 10T 5G"),
        ("Asus ROG Zephyrus G14", "ASUS Rogers Zephyrus G14 14” 165Hz Gaming Laptop", "Asus TUF Gaming F15"),
        ("Dell XPS 15 9530", "Dell XPS 15 Laptop, 13th Gen Intel Core", "Dell Inspiron 16 Plus"),
        ("Lenovo ThinkPad X1 Carbon Gen 11", "Lenovo ThinkPad X1 Carbon Gen 11 14 inch", "Lenovo ThinkPad T14s"),
        ("HP Spectre x360 14", "HP Spectre x360 2-in-1 Laptop 13.5t", "HP Envy x360 15"),
        ("Microsoft Surface Pro 9", "Microsoft Surface Pro 9 (2022), 13 2-in-1", "Microsoft Surface Laptop 5"),
    ]

    a_list = []
    b_list = []

    # 1. Add core pairs: every A item has exactly one true match in B plus a
    #    hard-negative distractor in B.
    for a, b, distractor in base_products:
        a_list.append(a)
        b_list.append(b)
        b_list.append(distractor)

    # 2. Add algorithmic filler (35 iterations pushes the total past 100).
    for i in range(35):
        a_list.append(f"Generic Widget Model X-{i+100} Pro")
        b_list.append(f"Generic Widget Series X {i+100} Professional Edition")
        b_list.append(f"Generic Widget Model X-{i+100} Lite")  # Distractor

        a_list.append(f"Industrial Part #44-A{i}")
        b_list.append(f"Genuine Industrial Part Number 44-A{i} Replacement")
        b_list.append(f"Industrial Part #44-B{i}")  # Distractor

    # Shuffle B so the match is never trivially at the same index as A.
    random.shuffle(b_list)

    # Writing to the widget keys before the widgets are instantiated (this
    # button sits above the text areas) makes the new data appear on rerun.
    st.session_state.txt_a = "\n".join(a_list)
    st.session_state.txt_b = "\n".join(b_list)

    st.success(f"Loaded {len(a_list)} items with hard negatives!")
    st.rerun()

col1, col2 = st.columns(2)

with col1:
    st.markdown("**Site A (Your Catalog)**")
    # Use key to bind to session state
    site_a_text = st.text_area("One product per line", key="txt_a", height=300)

with col2:
    st.markdown("**Site B (Competitor/Marketplace)**")
    site_b_text = st.text_area("One product per line", key="txt_b", height=300)

# Process
if st.button("Run Comparison", type="primary"):
    site_a = [x.strip() for x in site_a_text.split('\n') if x.strip()]
    site_b = [x.strip() for x in site_b_text.split('\n') if x.strip()]

    if not site_a or not site_b:
        st.warning("Please provide data for both sites.")
        st.stop()

    st.subheader("2. Matching Results")

    results = []
    progress_bar = st.progress(0)

    for i, product_a in enumerate(site_a):
        # Retrieve only the single best candidate from Site B.
        matches = model.search(product_a, site_b, top_k=1)

        if matches:
            best = matches[0]
            score = best['score']
            is_match = score >= threshold

            results.append({
                "Site A Product": product_a,
                "Best Match (Site B)": best['text'],
                "Confidence": score,
                "Status": "Match" if is_match else "Different"
            })
        else:
            results.append({
                "Site A Product": product_a,
                "Best Match (Site B)": "No candidate found",
                "Confidence": 0.0,
                "Status": "No Data"
            })

        progress_bar.progress((i + 1) / len(site_a))

    df = pd.DataFrame(results)

    # Most confident matches first.
    df = df.sort_values(by="Confidence", ascending=False)

    # Row highlighting: green for matches, red otherwise.
    def color_status(val):
        color = '#d4edda' if val == "Match" else '#f8d7da'
        return f'background-color: {color}'

    # NOTE: Styler.applymap is deprecated in pandas >= 2.1 in favor of
    # Styler.map; kept for compatibility with pandas 2.0.x.
    st.dataframe(
        df.style.applymap(color_status, subset=['Status'])
        .format({"Confidence": "{:.4f}"}),
        use_container_width=True
    )

    # Summary stats
    match_count = df[df['Status'] == "Match"].shape[0]
    st.metric("Total Matches Found", f"{match_count} / {len(site_a)}")
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4e423bda187891b565f4928d81f3e0786e4a87a56df5469a79adab4a7a35c05
3
+ size 63975744
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core
2
+ torch>=2.0.0
3
+ numpy>=1.21.0
4
+ tqdm>=4.64.0
5
+
6
+ # Demo UI
7
+ streamlit>=1.30.0
8
+ plotly>=5.0.0
9
+
10
+ # Optional (for clustering, CSV processing, & Benchmarking)
11
+ scikit-learn>=1.0.0
12
+ pandas>=2.0.0
13
+ psutil>=5.9.0
14
+ sentence-transformers>=2.2.0
15
+ safetensors>=0.4.0
src/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MiniEmbed - Lightweight Text Embedding Model
3
+ """
4
+
5
+ from .model import MiniTransformerEmbedding
6
+ from .tokenizer import SimpleTokenizer
7
+ from .inference import EmbeddingInference, EmbeddingModelManager
8
+
9
+ __version__ = "1.0.0"
10
+ __all__ = [
11
+ "MiniTransformerEmbedding",
12
+ "SimpleTokenizer",
13
+ "EmbeddingInference",
14
+ "EmbeddingModelManager"
15
+ ]
src/inference.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Saving & Inference Module
3
+ ===================================
4
+ Easy-to-use API for loading and running inference with the embedding model.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import json
11
+ import os
12
+ from pathlib import Path
13
+ from typing import List, Dict, Union, Tuple
14
+
15
+ from .model import MiniTransformerEmbedding
16
+ from .tokenizer import SimpleTokenizer
17
+
18
+
19
class EmbeddingModelManager:
    """
    Handles saving and loading the embedding model.

    Save structure:
        model_dir/
        ├── config.json          # Model architecture config
        ├── model.pt             # Model weights
        ├── tokenizer.json       # Vocabulary
        └── training_info.json   # Training metadata (optional)
    """

    # Weight file names load_model() can read, in order of preference.
    WEIGHT_FILES = ('model.safetensors', 'pytorch_model.bin', 'model.pt')

    @staticmethod
    def save_model(
        model: MiniTransformerEmbedding,
        tokenizer: SimpleTokenizer,
        save_dir: str,
        training_info: dict = None
    ):
        """
        Save model, tokenizer, and config for later use.

        Args:
            model: Trained MiniTransformerEmbedding
            tokenizer: SimpleTokenizer with vocabulary
            save_dir: Directory to save model (created if missing)
            training_info: Optional training metadata dict
        """
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        # 1. Save model config — everything needed to rebuild the architecture
        #    before loading weights.
        config = {
            'vocab_size': len(tokenizer.word_to_id),
            'd_model': model.d_model,
            'num_heads': model.layers[0].attention.num_heads,
            'num_layers': len(model.layers),
            'd_ff': model.layers[0].feed_forward.linear1.out_features,
            'max_seq_len': model.positional_encoding.pe.size(1),
            'pad_token_id': model.pad_token_id,
            'size_name': save_dir.name  # Use folder name as size name
        }

        with open(save_dir / 'config.json', 'w') as f:
            json.dump(config, f, indent=2)

        # 2. Save model weights
        torch.save(model.state_dict(), save_dir / 'model.pt')

        # 3. Save tokenizer vocabulary
        tokenizer.save(str(save_dir / 'tokenizer.json'))

        # 4. Save training info (optional)
        if training_info:
            with open(save_dir / 'training_info.json', 'w') as f:
                json.dump(training_info, f, indent=2)

        print(f"Model saved to: {save_dir}")

    @staticmethod
    def load_model(model_dir: str, device: str = None) -> Tuple[MiniTransformerEmbedding, SimpleTokenizer]:
        """
        Load model and tokenizer from directory.

        Args:
            model_dir: Directory containing saved model
            device: Device to load model on ('cpu', 'cuda', 'mps');
                auto-detected when None

        Returns:
            (model, tokenizer) tuple

        Raises:
            FileNotFoundError: if no supported weight file exists in model_dir
        """
        model_dir = Path(model_dir)

        # Auto-detect the best available device when none is requested.
        if device is None:
            if torch.cuda.is_available():
                device = 'cuda'
            elif torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'

        # 1. Load config
        config_path = model_dir / 'config.json'

        # If loading a checkpoint, the config might live in the sibling
        # 'models' folder instead of the 'checkpoints' folder.
        if not config_path.exists() and 'checkpoints' in str(model_dir):
            potential_models_dir = Path(str(model_dir).replace('checkpoints', 'models'))
            if (potential_models_dir / 'config.json').exists():
                config_path = potential_models_dir / 'config.json'

        if config_path.exists():
            with open(config_path, 'r') as f:
                config = json.load(f)
        else:
            # Fallback for Product Model (Hardcoded defaults)
            print("Warning: config.json not found. Using default Product Model configuration.")
            config = {
                "vocab_size": 50000,
                "d_model": 256,
                "num_heads": 4,
                "num_layers": 4,
                "d_ff": 1024,
                "max_seq_len": 128,
                "pad_token_id": 0
            }

        # 2. Load tokenizer (same checkpoints->models fallback as the config)
        tokenizer_path = model_dir / 'tokenizer.json'
        if not tokenizer_path.exists() and 'checkpoints' in str(model_dir):
            potential_models_dir = Path(str(model_dir).replace('checkpoints', 'models'))
            if (potential_models_dir / 'tokenizer.json').exists():
                tokenizer_path = potential_models_dir / 'tokenizer.json'

        tokenizer = SimpleTokenizer(vocab_size=config['vocab_size'])
        tokenizer.load(str(tokenizer_path))

        # 3. Create and load model
        model = MiniTransformerEmbedding(
            vocab_size=config['vocab_size'],
            d_model=config['d_model'],
            num_heads=config['num_heads'],
            num_layers=config['num_layers'],
            d_ff=config['d_ff'],
            max_seq_len=config['max_seq_len'],
            pad_token_id=config.get('pad_token_id', 0)
        )

        # Try safetensors first (preferred), then pytorch_model.bin, then model.pt
        from safetensors.torch import load_file

        safe_path = model_dir / 'model.safetensors'
        bin_path = model_dir / 'pytorch_model.bin'
        pt_path = model_dir / 'model.pt'

        if safe_path.exists():
            state_dict = load_file(safe_path)
        elif bin_path.exists():
            # weights_only=True avoids executing arbitrary pickled code.
            state_dict = torch.load(bin_path, map_location=device, weights_only=True)
        elif pt_path.exists():
            state_dict = torch.load(pt_path, map_location=device, weights_only=True)
        else:
            raise FileNotFoundError(f"No model weights found in {model_dir}")

        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()

        return model, tokenizer

    @staticmethod
    def list_models(base_dir: str = "models") -> List[str]:
        """
        List available model names in the base directory.

        Returns:
            Sorted list of directory names containing valid models
        """
        path = Path(base_dir)
        if not path.exists():
            return []
        # A directory counts as a model if it holds ANY loadable weight file.
        # (Previously only `model.pt` was recognized, hiding safetensors /
        # pytorch_model.bin checkpoints that load_model() happily reads.)
        return sorted(
            d.name
            for d in path.iterdir()
            if d.is_dir() and any((d / w).exists() for w in EmbeddingModelManager.WEIGHT_FILES)
        )
181
+
182
class EmbeddingInference:
    """
    High-level inference API for the embedding model.

    Usage:
        model = EmbeddingInference.from_pretrained("./model")

        # Encode texts
        embeddings = model.encode(["Hello world", "Machine learning"])

        # Compute similarity
        score = model.similarity("query", "document")

        # Semantic search
        results = model.search("python programming", documents)
    """

    def __init__(
        self,
        model: MiniTransformerEmbedding,
        tokenizer: SimpleTokenizer,
        device: str = 'cpu',
        max_length: int = 64
    ):
        # NOTE(review): the constructor default max_length (64) differs from
        # from_pretrained's default (128); kept as-is for backward
        # compatibility with direct constructor callers — confirm intended.
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length
        self.model.eval()  # inference only: disable dropout

    @classmethod
    def from_pretrained(cls, model_dir: str, device: str = None, max_length: int = 128):
        """Load from a saved model directory (see EmbeddingModelManager)."""
        model, tokenizer = EmbeddingModelManager.load_model(model_dir, device)
        if device is None:
            # load_model auto-selected a device; mirror its choice here.
            device = next(model.parameters()).device.type
        return cls(model, tokenizer, device, max_length=max_length)

    def encode(
        self,
        texts: Union[str, List[str]],
        batch_size: int = 32,
        show_progress: bool = False
    ) -> np.ndarray:
        """
        Encode texts to embeddings.

        Args:
            texts: Single text or list of texts
            batch_size: Batch size for encoding
            show_progress: Print a simple batch counter while encoding
                (this flag was previously accepted but ignored)

        Returns:
            numpy array of shape (n_texts, d_model)
        """
        if isinstance(texts, str):
            texts = [texts]

        # Empty input: return a well-shaped empty array instead of letting
        # np.vstack raise on an empty list.
        if not texts:
            return np.empty((0, self.model.d_model), dtype=np.float32)

        all_embeddings = []
        num_batches = (len(texts) + batch_size - 1) // batch_size

        # Process in batches
        for batch_idx, start in enumerate(range(0, len(texts), batch_size)):
            batch_texts = texts[start:start + batch_size]

            # Tokenize each text to fixed-length id/mask tensors.
            encodings = [
                self.tokenizer.encode(t, self.max_length)
                for t in batch_texts
            ]

            input_ids = torch.stack([e['input_ids'] for e in encodings]).to(self.device)
            attention_mask = torch.stack([e['attention_mask'] for e in encodings]).to(self.device)

            # Encode without tracking gradients (inference only).
            with torch.no_grad():
                embeddings = self.model.encode(input_ids, attention_mask)

            all_embeddings.append(embeddings.cpu().numpy())

            if show_progress:
                # Lightweight stdlib progress line (avoids a tqdm dependency).
                print(f"\rEncoding batch {batch_idx + 1}/{num_batches}", end="", flush=True)

        if show_progress:
            print()

        return np.vstack(all_embeddings)

    def similarity(self, text1: str, text2: str) -> float:
        """Compute cosine similarity between two texts.

        The dot product equals cosine similarity because the model
        L2-normalizes its embeddings.
        """
        emb1 = self.encode(text1)
        emb2 = self.encode(text2)
        return float(np.dot(emb1[0], emb2[0]))

    def pairwise_similarity(self, texts1: List[str], texts2: List[str]) -> np.ndarray:
        """
        Compute pairwise similarity between two lists.

        Returns:
            Matrix of shape (len(texts1), len(texts2))
        """
        emb1 = self.encode(texts1)
        emb2 = self.encode(texts2)
        return np.dot(emb1, emb2.T)

    def search(
        self,
        query: str,
        documents: List[str],
        top_k: int = 5
    ) -> List[Dict]:
        """
        Semantic search: Find most similar documents to query.

        Args:
            query: Search query
            documents: List of documents to search
            top_k: Number of results to return

        Returns:
            List of dicts with 'rank', 'text', 'score', 'index',
            ordered best-first (empty when documents is empty)
        """
        query_emb = self.encode(query)
        doc_embs = self.encode(documents)

        # Cosine similarities (embeddings are L2-normalized by the model).
        scores = np.dot(doc_embs, query_emb.T).flatten()

        # Indices of the top-k highest scores, best first.
        top_indices = np.argsort(scores)[::-1][:top_k]

        results = []
        for rank, idx in enumerate(top_indices, 1):
            results.append({
                'rank': rank,
                'text': documents[idx],
                'score': float(scores[idx]),
                'index': int(idx)
            })

        return results

    def cluster_texts(self, texts: List[str], n_clusters: int = 5) -> Dict:
        """
        Cluster texts by embedding similarity using KMeans.

        Returns:
            Dict with 'labels', 'centroids', and 'texts_by_cluster'
        """
        # Local import keeps scikit-learn an optional dependency.
        from sklearn.cluster import KMeans

        embeddings = self.encode(texts)

        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        labels = kmeans.fit_predict(embeddings)

        return {
            'labels': labels.tolist(),
            'centroids': kmeans.cluster_centers_,
            'texts_by_cluster': {
                i: [texts[j] for j in range(len(texts)) if labels[j] == i]
                for i in range(n_clusters)
            }
        }
src/model.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mini-Transformer Embedding Model
3
+ ====================================
4
+ A lightweight transformer encoder for generating text embeddings.
5
+ Built from scratch using PyTorch.
6
+
7
+ Architecture:
8
+ - Token Embeddings + Sinusoidal Positional Encoding
9
+ - N Transformer Encoder Layers (Pre-LayerNorm)
10
+ - Multi-Head Self-Attention
11
+ - Position-wise Feed-Forward Networks
12
+ - Mean Pooling + L2 Normalization
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import math
19
+ from typing import Optional
20
+
21
+
22
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sin/cos positional encoding ("Attention Is All You Need").

    Injects token-order information by adding a precomputed, non-learned
    sinusoid table to the input embeddings, then applying dropout. Even
    channels carry sine, odd channels carry cosine, each pair at its own
    frequency.
    """

    def __init__(self, d_model: int, max_seq_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # One frequency per (sin, cos) channel pair: 10000^(-2i/d_model).
        even_channels = torch.arange(0, d_model, 2, dtype=torch.float)
        inv_freq = torch.exp(even_channels * (-math.log(10000.0) / d_model))

        # Outer product of positions and frequencies -> [max_seq_len, d_model/2].
        positions = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)
        angles = positions * inv_freq

        table = torch.zeros(max_seq_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Registered as a buffer (not a parameter): saved with the model,
        # never trained. Leading batch dim of 1 lets it broadcast.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to `x` ([batch, seq, d_model]), with dropout."""
        seq_len = x.size(1)
        return self.dropout(x + self.pe[:, :seq_len, :])
60
+
61
+
62
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects the input into `num_heads` independent (Q, K, V) subspaces of
    width d_k = d_model / num_heads, attends within each head, then merges
    the heads through a final linear map.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head width

        # Separate Q/K/V projections (layer names must match checkpoint keys).
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Final merge of the concatenated heads.
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)  # attention logit scaling factor

    def _split_heads(self, t: torch.Tensor, batch: int, seq: int) -> torch.Tensor:
        """Reshape [batch, seq, d_model] -> [batch, num_heads, seq, d_k]."""
        return t.view(batch, seq, self.num_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Attend over `x` ([batch, seq, d_model]).

        `attention_mask` is [batch, seq] with 0 marking padded key positions.
        Returns a tensor of the same shape as `x`.
        """
        batch, seq, _ = x.size()

        q = self._split_heads(self.W_q(x), batch, seq)
        k = self._split_heads(self.W_k(x), batch, seq)
        v = self._split_heads(self.W_v(x), batch, seq)

        # Similarity logits [batch, heads, seq, seq], scaled by sqrt(d_k).
        logits = q.matmul(k.transpose(-2, -1)) / self.scale

        if attention_mask is not None:
            # Broadcast the key-padding mask to [batch, 1, 1, seq] so no
            # query can attend onto a padded position.
            pad = attention_mask.unsqueeze(1).unsqueeze(2)
            logits = logits.masked_fill(pad == 0, float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))

        # Weighted sum of values, heads re-merged into d_model.
        merged = weights.matmul(v).transpose(1, 2).contiguous().view(batch, seq, self.d_model)
        return self.W_o(merged)
138
+
139
+
140
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP (d_model -> d_ff -> d_model) with GELU activation,
    applied independently and identically at every sequence position."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map [batch, seq, d_model] -> [batch, seq, d_model]."""
        # expand -> nonlinearity -> dropout -> project back down
        hidden = self.dropout(F.gelu(self.linear1(x)))
        return self.linear2(hidden)
166
+
167
+
168
class TransformerEncoderLayer(nn.Module):
    """One pre-LayerNorm transformer encoder block:

        x -> x + Dropout(Attention(LN1(x)))
          -> x + Dropout(FeedForward(LN2(x)))

    Normalizing *before* each sub-layer (pre-norm) tends to train more
    stably than the original post-norm arrangement.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1
    ):
        super().__init__()

        # Sub-modules are created in this exact order so random parameter
        # initialization (and state_dict keys) match existing checkpoints.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attention = MultiHeadSelfAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the block to `x` ([batch, seq, d_model]).

        `attention_mask` is [batch, seq] with 0 marking padding; it is only
        consumed by the attention sub-layer.
        """
        # Residual around the attention sub-layer (pre-normed input).
        x = x + self.dropout(self.attention(self.norm1(x), attention_mask))
        # Residual around the position-wise feed-forward sub-layer.
        x = x + self.dropout(self.feed_forward(self.norm2(x)))
        return x
222
+
223
+
224
class MiniTransformerEmbedding(nn.Module):
    """Mini-Transformer encoder producing fixed-size sentence embeddings.

    Pipeline:
        1. Token embedding (scaled by sqrt(d_model))
        2. Sinusoidal positional encoding
        3. ``num_layers`` pre-norm Transformer encoder layers
        4. Final LayerNorm

    ``forward`` returns token-level representations; ``encode`` additionally
    applies masked mean pooling and L2 normalization so the resulting
    vectors can be compared directly with cosine similarity.
    """

    def __init__(
        self,
        vocab_size: int = 30000,
        d_model: int = 256,
        num_heads: int = 4,
        num_layers: int = 4,
        d_ff: int = 1024,
        max_seq_len: int = 128,
        dropout: float = 0.1,
        pad_token_id: int = 0
    ):
        super().__init__()

        self.d_model = d_model
        self.pad_token_id = pad_token_id

        # Token lookup table; the padding row is excluded from gradient updates.
        self.token_embedding = nn.Embedding(
            vocab_size, d_model, padding_idx=pad_token_id
        )

        # Fixed (non-learned) positional information.
        self.positional_encoding = SinusoidalPositionalEncoding(
            d_model, max_seq_len, dropout
        )

        # Stack of identical pre-norm encoder layers.
        self.layers = nn.ModuleList(
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        )

        # Final normalization of token representations.
        self.final_norm = nn.LayerNorm(d_model)

        self._init_weights()

    def _init_weights(self):
        """Xavier-init linears; N(0, 0.02) for embeddings with a zeroed pad row."""
        for module in self.modules():
            if isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0, std=0.02)
                if module.padding_idx is not None:
                    nn.init.zeros_(module.weight[module.padding_idx])
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Encode tokens into contextual representations.

        Args:
            input_ids: Token IDs [batch_size, seq_len].
            attention_mask: Optional padding mask [batch_size, seq_len].

        Returns:
            Token-level representations [batch_size, seq_len, d_model].
        """
        # Scale embeddings by sqrt(d_model), then add positional information.
        hidden = self.positional_encoding(
            self.token_embedding(input_ids) * math.sqrt(self.d_model)
        )

        # Run the encoder stack.
        for block in self.layers:
            hidden = block(hidden, attention_mask)

        return self.final_norm(hidden)

    def encode(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Produce one L2-normalized embedding vector per input sequence.

        Mean-pools the token representations (ignoring padded positions
        when a mask is given), then L2-normalizes for cosine similarity.

        Args:
            input_ids: Token IDs [batch_size, seq_len].
            attention_mask: Optional padding mask [batch_size, seq_len].

        Returns:
            Normalized embeddings [batch_size, d_model].
        """
        token_states = self.forward(input_ids, attention_mask)

        if attention_mask is None:
            # No mask: plain mean over the sequence dimension.
            pooled = token_states.mean(dim=1)
        else:
            # Masked mean: sum valid positions, divide by their count.
            mask = attention_mask.unsqueeze(-1).float()
            # clamp guards against division by zero for all-pad sequences.
            pooled = (token_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

        return F.normalize(pooled, p=2, dim=1)
src/tokenizer.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple Word-Level Tokenizer
3
+ ==============================
4
+ A basic tokenizer for demonstration purposes.
5
+ Converts text to token IDs with special tokens.
6
+ """
7
+
8
+ import re
9
+ import json
10
+ import torch
11
+ from typing import Dict, List, Optional
12
+ from collections import Counter
13
+ from tqdm import tqdm
14
+
15
+
16
class SimpleTokenizer:
    """
    A simple word-level tokenizer with special tokens.

    Special Tokens:
        - [PAD]: Padding token (id=0)
        - [UNK]: Unknown token (id=1)
        - [CLS]: Classification token (id=2)
        - [SEP]: Separator token (id=3)
    """

    def __init__(self, vocab_size: int = 30000):
        self.vocab_size = vocab_size

        # Special tokens occupy the first vocabulary slots.
        self.special_tokens = {
            '[PAD]': 0,
            '[UNK]': 1,
            '[CLS]': 2,
            '[SEP]': 3,
        }

        # Word <-> ID mappings, seeded with the special tokens.
        self.word_to_id: Dict[str, int] = dict(self.special_tokens)
        self.id_to_word: Dict[int, str] = {v: k for k, v in self.special_tokens.items()}

        # Convenience aliases for the special token IDs.
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.cls_token_id = 2
        self.sep_token_id = 3

    def _tokenize(self, text: str) -> List[str]:
        """
        Split text into tokens (simple word-level tokenization).

        Lowercases the input, then splits it into whole words and
        individual punctuation characters.

        Args:
            text: Input text string

        Returns:
            List of tokens
        """
        # Lowercase and basic cleaning
        text = text.lower().strip()

        # \b\w+\b matches whole words; [^\w\s] matches each punctuation char.
        tokens = re.findall(r'\b\w+\b|[^\w\s]', text)

        return tokens

    def build_vocab(self, texts: List[str], min_freq: int = 2):
        """
        Build vocabulary from a list of texts.

        Args:
            texts: List of text strings
            min_freq: Minimum frequency for a word to be included
        """
        # Count word frequencies
        word_counts = Counter()
        for text in tqdm(texts, desc="Building vocabulary"):
            word_counts.update(self._tokenize(text))

        max_words = self.vocab_size - len(self.special_tokens)

        # Fix: apply the min_freq filter BEFORE truncating to the size
        # budget. The previous order (truncate, then filter) let rare words
        # inside the top slice waste vocabulary slots while frequent-enough
        # words just past the slice boundary were dropped.
        # most_common() is already sorted by descending frequency.
        eligible = [word for word, count in word_counts.most_common() if count >= min_freq]

        for word in eligible[:max_words]:
            if word not in self.word_to_id:
                idx = len(self.word_to_id)
                self.word_to_id[word] = idx
                self.id_to_word[idx] = word

        print(f"Vocabulary size: {len(self.word_to_id)}")

    def encode(self, text: str, max_length: int = 128) -> Dict:
        """
        Encode text into fixed-length token IDs and an attention mask.

        The sequence is wrapped as [CLS] tokens... [SEP], truncated to fit
        max_length, and right-padded with [PAD].

        Args:
            text: Input text string
            max_length: Fixed output length (should be >= 2 for CLS/SEP)

        Returns:
            Dict with 'input_ids' and 'attention_mask' long tensors of
            shape [max_length].
        """
        # Tokenize
        tokens = self._tokenize(text)

        # Reserve space for CLS and SEP. max(0, ...) guards tiny max_length
        # values, where max_length - 2 would be a negative slice bound and
        # silently keep nearly all tokens.
        keep = max(0, max_length - 2)

        token_ids = [self.cls_token_id]
        token_ids.extend(
            self.word_to_id.get(token, self.unk_token_id) for token in tokens[:keep]
        )
        token_ids.append(self.sep_token_id)

        # Real tokens get mask 1; padding gets mask 0.
        attention_mask = [1] * len(token_ids)

        # Pad to max_length
        padding_length = max_length - len(token_ids)
        token_ids.extend([self.pad_token_id] * padding_length)
        attention_mask.extend([0] * padding_length)

        return {
            'input_ids': torch.tensor(token_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long)
        }

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to text.

        PAD/CLS/SEP tokens are dropped; unknown IDs render as '[UNK]'.

        Args:
            token_ids: List of token IDs

        Returns:
            Decoded text string (space-joined tokens)
        """
        tokens = []
        for idx in token_ids:
            # Skip structural special tokens.
            if idx in [self.pad_token_id, self.cls_token_id, self.sep_token_id]:
                continue
            tokens.append(self.id_to_word.get(idx, '[UNK]'))
        return ' '.join(tokens)

    def save(self, path: str):
        """Save tokenizer vocabulary to JSON file."""
        data = {
            'vocab_size': self.vocab_size,
            'word_to_id': self.word_to_id,
        }
        with open(path, 'w') as f:
            json.dump(data, f, indent=2)

    def load(self, path: str):
        """Load tokenizer vocabulary from JSON file."""
        with open(path, 'r') as f:
            data = json.load(f)

        self.vocab_size = data['vocab_size']
        self.word_to_id = data['word_to_id']
        # JSON keys are strings; rebuild the reverse map with int IDs.
        self.id_to_word = {int(v): k for k, v in self.word_to_id.items()}

    def __len__(self) -> int:
        # Vocabulary size including special tokens.
        return len(self.word_to_id)
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff