adambuttrick committed
Commit ff2449f · verified · 1 Parent(s): e1edb4a

Add script for generating embeddings

Files changed (1)
  1. generate_embeddings.py +338 -0
generate_embeddings.py ADDED
@@ -0,0 +1,338 @@
import os
import sys
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Optional
import warnings

import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
from datasets import Dataset, DatasetDict
from transformers import AutoModel, AutoTokenizer

warnings.filterwarnings('ignore')

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('embedding_generation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


class AffiliationEmbedder:
    def __init__(
        self,
        model_path: str = "./affiliation-clustering-0.3b",
        device: Optional[str] = None,
        batch_size: int = 32,
        max_length: int = 512,
        use_fp16: bool = False
    ):
        self.model_path = model_path
        self.batch_size = batch_size
        self.max_length = max_length
        self.use_fp16 = use_fp16

        # Auto-detect the device unless one was given explicitly.
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        logger.info(f"Using device: {self.device}")
        if self.device.type == 'cuda':
            logger.info(f"GPU: {torch.cuda.get_device_name()}")
            logger.info(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        self._load_model()

    def _load_model(self):
        logger.info(f"Loading model from {self.model_path}")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )

            self.model = AutoModel.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )

            self.model = self.model.to(self.device)

            if self.use_fp16 and self.device.type == 'cuda':
                self.model = self.model.half()
                logger.info("Using FP16 mixed precision")

            self.model.eval()

            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def encode_batch(self, texts: List[str]) -> np.ndarray:
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        with torch.no_grad():
            outputs = self.model(**encoded)

        # Prefer the model's pooled output; otherwise fall back to
        # attention-mask-weighted mean pooling over token embeddings.
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            embeddings = outputs.pooler_output
        else:
            token_embeddings = outputs.last_hidden_state
            attention_mask = encoded['attention_mask'].unsqueeze(-1)
            masked_embeddings = token_embeddings * attention_mask
            # Clamp the denominator so an all-padding row cannot divide by zero.
            embeddings = masked_embeddings.sum(dim=1) / attention_mask.sum(dim=1).clamp(min=1)

        # L2-normalize so downstream cosine similarity reduces to a dot product.
        embeddings = F.normalize(embeddings, p=2, dim=1)

        embeddings = embeddings.cpu().numpy()

        if self.use_fp16:
            embeddings = embeddings.astype(np.float32)

        return embeddings

    def process_dataset(
        self,
        data_path: str,
        output_path: str,
        checkpoint_interval: int = 1000
    ) -> None:

        logger.info(f"Processing dataset: {data_path}")

        df = pd.read_parquet(data_path)
        logger.info(f"Loaded {len(df)} samples")

        checkpoint_path = output_path.replace('.parquet', '_checkpoint.parquet')
        start_idx = 0
        processed_rows = []

        # Resume from a previous run: reload the already-embedded rows so they
        # are carried into the final output rather than silently dropped.
        if os.path.exists(checkpoint_path):
            logger.info(f"Found checkpoint at {checkpoint_path}")
            checkpoint_df = pd.read_parquet(checkpoint_path)
            processed_rows = checkpoint_df.to_dict('records')
            start_idx = len(processed_rows)
            logger.info(f"Resuming from index {start_idx}")

        total_batches = (len(df) - start_idx + self.batch_size - 1) // self.batch_size

        with tqdm(total=total_batches, desc="Generating embeddings") as pbar:
            for i in range(start_idx, len(df), self.batch_size):
                batch_df = df.iloc[i:i+self.batch_size]
                texts = batch_df['affiliation_name'].tolist()

                try:
                    batch_embeddings = self.encode_batch(texts)

                    for j, embedding in enumerate(batch_embeddings):
                        row_idx = i + j
                        row_data = df.iloc[row_idx].to_dict()
                        row_data['embedding'] = embedding
                        processed_rows.append(row_data)

                        if len(processed_rows) % checkpoint_interval == 0:
                            self._save_checkpoint(processed_rows, checkpoint_path)
                            logger.info(f"Checkpoint saved at {len(processed_rows)} samples")

                    pbar.update(1)

                except Exception as e:
                    logger.error(f"Error processing batch at index {i}: {e}")
                    if processed_rows:
                        self._save_checkpoint(processed_rows, checkpoint_path)
                    raise

        result_df = pd.DataFrame(processed_rows)

        logger.info(f"Saving embeddings to {output_path}")
        result_df.to_parquet(output_path, compression='snappy')

        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)
            logger.info("Checkpoint file removed")

        logger.info(f"Successfully generated embeddings for {len(result_df)} samples")

        embedding_dim = len(result_df['embedding'].iloc[0])
        logger.info(f"Embedding dimension: {embedding_dim}")
        logger.info(f"Output file size: {os.path.getsize(output_path) / 1e6:.2f} MB")

    def _save_checkpoint(self, processed_rows: List[Dict], checkpoint_path: str):
        checkpoint_df = pd.DataFrame(processed_rows)
        checkpoint_df.to_parquet(checkpoint_path, compression='snappy')


def main():
    parser = argparse.ArgumentParser(
        description="Generate embeddings for affiliation strings"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="./affiliation-clustering-0.3b",
        help="Path to the pre-trained model directory"
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./20250727-unique-openalex-affiliations-w-ror-ids-top-1K-ror-ids-100-per-sample",
        help="Directory containing the input parquet files"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./20250727-unique-openalex-affiliations-w-ror-ids-top-1K-ror-ids-100-per-sample-embeddings",
        help="Directory to save the output embeddings"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for processing"
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Maximum sequence length for tokenization"
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (cuda/cpu, auto-detect if not specified)"
    )
    parser.add_argument(
        "--use-fp16",
        action="store_true",
        help="Use FP16 mixed precision for faster processing"
    )
    parser.add_argument(
        "--checkpoint-interval",
        type=int,
        default=1000,
        help="Save a checkpoint every N processed samples"
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the resulting dataset to Hugging Face Hub"
    )
    parser.add_argument(
        "--hub-dataset-id",
        type=str,
        default=None,
        help="Hugging Face Hub dataset ID (required if --push-to-hub is set)"
    )

    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    embedder = AffiliationEmbedder(
        model_path=args.model_path,
        device=args.device,
        batch_size=args.batch_size,
        max_length=args.max_length,
        use_fp16=args.use_fp16
    )

    data_dir = Path(args.data_dir)
    train_file = list(data_dir.glob("*_train.parquet"))[0]
    test_file = list(data_dir.glob("*_test.parquet"))[0]

    train_output = output_dir / "train_embeddings.parquet"
    test_output = output_dir / "test_embeddings.parquet"

    logger.info("Processing training dataset...")
    embedder.process_dataset(
        str(train_file),
        str(train_output),
        checkpoint_interval=args.checkpoint_interval
    )

    logger.info("Processing test dataset...")
    embedder.process_dataset(
        str(test_file),
        str(test_output),
        checkpoint_interval=args.checkpoint_interval
    )

    if args.push_to_hub:
        if not args.hub_dataset_id:
            logger.error("--hub-dataset-id is required when --push-to-hub is set")
            sys.exit(1)

        logger.info(f"Pushing dataset to Hugging Face Hub: {args.hub_dataset_id}")

        try:
            from huggingface_hub import login

            token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
            if token:
                login(token=token)
                logger.info("Authenticated with Hugging Face Hub using token")
            else:
                logger.info("No HF token found in environment, attempting to use existing credentials")

            logger.info("Loading generated embeddings...")
            train_df = pd.read_parquet(train_output)
            test_df = pd.read_parquet(test_output)

            logger.info(f"Train dataset: {len(train_df)} samples")
            logger.info(f"Test dataset: {len(test_df)} samples")

            logger.info("Creating dataset dictionary...")
            dataset_dict = DatasetDict({
                'train': Dataset.from_pandas(train_df),
                'test': Dataset.from_pandas(test_df)
            })

            logger.info(f"Pushing to hub: {args.hub_dataset_id}")
            dataset_dict.push_to_hub(
                args.hub_dataset_id,
                private=False,
                commit_message="Add affiliation embeddings generated with affiliation-clustering-0.3b model"
            )
            logger.info(f"Dataset successfully pushed to {args.hub_dataset_id}")
            logger.info(f"View at: https://huggingface.co/datasets/{args.hub_dataset_id}")

        except ImportError as e:
            logger.error(f"Failed to import required libraries: {e}")
            logger.error("Make sure huggingface_hub and datasets are installed")
            sys.exit(1)
        except Exception as e:
            logger.error(f"Failed to push dataset to hub: {e}")
            logger.error(f"Error type: {type(e).__name__}")
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            sys.exit(1)

    logger.info("Embedding generation completed successfully!")


if __name__ == "__main__":
    main()
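
For reference, a minimal usage sketch of the class above. The affiliation strings and batch size are illustrative (not from this repository); the model path is the script's default:

from generate_embeddings import AffiliationEmbedder

# Typical CLI run instead: python generate_embeddings.py --use-fp16 --batch-size 64
embedder = AffiliationEmbedder(model_path="./affiliation-clustering-0.3b", batch_size=8)

# Embed a few affiliation strings directly, bypassing the parquet pipeline.
vectors = embedder.encode_batch([
    "Department of Physics, University of Oxford, Oxford, UK",
    "Massachusetts Institute of Technology, Cambridge, MA, USA",
])
print(vectors.shape)  # (2, hidden_size)

# encode_batch L2-normalizes each row, so cosine similarity is a plain dot product.
similarity = float(vectors[0] @ vectors[1])
print(f"cosine similarity: {similarity:.3f}")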