| |
| """ |
| Regenerate answer embeddings using the MuRIL model. |
| This script: |
| - downloads model (if MODEL_DIR is a repo id), |
| - reads CSV at CSV_PATH, |
| - computes mean-pooled, L2-normalized embeddings for 'answer' column, |
| - saves embeddings to OUT_EMBED_PATH. |
| |
| Exit codes: |
| - 0 on success |
| - non-zero on failure |
| """ |
| import os, argparse, math, sys |
| from pathlib import Path |
| import torch |
| import pandas as pd |
| from tqdm.auto import tqdm |
| from transformers import AutoTokenizer, AutoModel |
| from huggingface_hub import snapshot_download |
|
|
def mean_pooling(last_hidden_state, attention_mask):
    """Average hidden states along dimension 1, ignoring masked positions.

    The mask gains a trailing axis, is broadcast to the shape of
    ``last_hidden_state`` and converted to float, so positions whose mask
    value is 0 contribute nothing to the sum.

    Args:
        last_hidden_state: Hidden-state tensor; presumably
            ``(batch, seq_len, hidden)`` as produced by the encoder in
            ``main`` -- confirm against the caller.
        attention_mask: Mask with one dimension fewer than
            ``last_hidden_state``; non-zero marks positions to keep.

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask total.
    """
    weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    weighted_total = (last_hidden_state * weights).sum(1)
    # Clamp the denominator so a fully masked row yields zeros instead of NaN.
    weight_total = weights.sum(1).clamp(min=1e-9)
    return weighted_total / weight_total
|
|
def parse_env():
    """Build the run configuration from environment variables.

    Variables read (with their fallbacks):
        MODEL_DIR        -> ``model_dir``; falls back to HF_REPO, then to the
                            built-in repo id.
        CSV_PATH         -> ``csv_path``
        OUT_EMBED_PATH   -> ``out_path``
        EMBED_BATCH_SIZE -> ``batch_size`` (parsed with ``int``; default 64)
        DEVICE           -> ``device``; "cuda" when CUDA is available,
                            otherwise "cpu".
        HF_CACHE_DIR     -> ``download_cache``
        UPLOAD_BACK      -> ``upload_back``; True for "1", "true" or "yes"
                            (case-insensitive).
        HF_REPO          -> ``hf_repo``; None when unset.

    Returns:
        A dict with the keys listed above, in that order.

    Raises:
        ValueError: If EMBED_BATCH_SIZE is not a valid integer literal.
    """
    env = os.environ
    truthy = ("1", "true", "yes")
    return {
        'model_dir': env.get(
            "MODEL_DIR",
            env.get("HF_REPO", "Sp2503/Finetuned-multilingualdataset-MuriL-model"),
        ),
        'csv_path': env.get("CSV_PATH", "/app/export_artifacts/muril_multilingual_dataset.csv"),
        'out_path': env.get("OUT_EMBED_PATH", "/app/export_artifacts/answer_embeddings.pt"),
        'batch_size': int(env.get("EMBED_BATCH_SIZE", "64")),
        'device': env.get("DEVICE", "cuda" if torch.cuda.is_available() else "cpu"),
        'download_cache': env.get("HF_CACHE_DIR", "/tmp/hf_cache"),
        'upload_back': env.get("UPLOAD_BACK", "false").lower() in truthy,
        'hf_repo': env.get("HF_REPO"),
    }
|
|
def main():
    """Regenerate answer embeddings and write them to disk.

    Pipeline:
      1. Read the configuration from the environment via ``parse_env``.
      2. If ``model_dir`` contains "/" and is not an existing directory,
         treat it as a Hugging Face repo id and download a snapshot.
      3. Load the CSV and take its 'answer' column (missing values become "").
      4. Encode the answers in batches: tokenize (truncated to 256 tokens),
         mean-pool the last hidden state, L2-normalize each row.
      5. Save the stacked embeddings with ``torch.save``.
      6. Optionally upload the file back to ``hf_repo`` (best effort).

    Returns:
        0 on success.

    Exits (via ``sys.exit``) with:
        2 -- model snapshot download failed
        3 -- CSV file not found
        4 -- CSV has no 'answer' column
        5 -- model or tokenizer failed to load
        6 -- error while encoding
        7 -- CSV contains no rows
        8 -- batch size is not a positive integer
    """
    cfg = parse_env()
    print("Regenerate embeddings with config:", cfg)

    # Validate before any download or model load: a step of 0 makes range()
    # raise, and a negative step silently produces no batches at all.
    batch_size = cfg['batch_size']
    if batch_size < 1:
        print(f"EMBED_BATCH_SIZE must be a positive integer, got {batch_size}", file=sys.stderr)
        sys.exit(8)

    model_dir = cfg['model_dir']
    # A value containing "/" that is not a local directory is assumed to be a
    # Hub repo id and is resolved to a local snapshot path.
    if "/" in model_dir and not os.path.isdir(model_dir):
        print("Detected HF repo id for model. snapshot_download ->", cfg['download_cache'])
        try:
            model_dir = snapshot_download(repo_id=cfg['model_dir'], repo_type="model", cache_dir=cfg['download_cache'])
            print("Downloaded model to:", model_dir)
        except Exception as e:
            print("Failed to snapshot_download model:", e, file=sys.stderr)
            sys.exit(2)

    csv_path = cfg['csv_path']
    out_path = cfg['out_path']
    device = cfg['device']
    print(f"Loading CSV: {csv_path}")
    if not os.path.isfile(csv_path):
        print(f"CSV not found at {csv_path}", file=sys.stderr)
        sys.exit(3)
    # Read everything as strings; empty cells become "" so every row still
    # produces exactly one embedding.
    df = pd.read_csv(csv_path, dtype=str).fillna("")
    if 'answer' not in df.columns:
        print("CSV must contain 'answer' column", file=sys.stderr)
        sys.exit(4)
    answers = df['answer'].astype(str).tolist()
    if not answers:
        # torch.cat() rejects an empty list, so fail with a clear message
        # instead of an unhandled traceback further down.
        print(f"CSV at {csv_path} contains no rows; nothing to encode", file=sys.stderr)
        sys.exit(7)
    print(f"Encoding {len(answers)} answers on device {device} (batch_size={batch_size})")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
        model = AutoModel.from_pretrained(model_dir)
        model.to(device)
        model.eval()
    except Exception as e:
        print("Failed to load model/tokenizer:", e, file=sys.stderr)
        sys.exit(5)

    all_embs = []
    try:
        with torch.inference_mode():
            for start in tqdm(range(0, len(answers), batch_size), desc="Batches"):
                batch = answers[start:start + batch_size]
                enc = tokenizer(batch, padding=True, truncation=True, max_length=256, return_tensors="pt")
                input_ids = enc["input_ids"].to(device)
                attention_mask = enc["attention_mask"].to(device)
                out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
                pooled = mean_pooling(out.last_hidden_state, attention_mask)
                pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
                # Move each batch to the CPU right away so device memory use
                # does not grow with the number of batches.
                all_embs.append(pooled.cpu())
    except Exception as e:
        print("Error during encoding:", e, file=sys.stderr)
        sys.exit(6)

    all_embs = torch.cat(all_embs, dim=0)
    print("Final embeddings shape:", all_embs.shape)
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(all_embs, out_path)
    print("Saved embeddings to:", out_path)

    # Best-effort upload: a failure is reported but does not fail the run,
    # because the embeddings have already been saved locally.
    if cfg['upload_back'] and cfg['hf_repo']:
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            print(f"Uploading {out_path} back to repo {cfg['hf_repo']} ...")
            api.upload_file(
                path_or_fileobj=out_path,
                path_in_repo=os.path.basename(out_path),
                repo_id=cfg['hf_repo'],
                repo_type="model",
            )
            print("Upload complete.")
        except Exception as e:
            print("Upload back failed:", e, file=sys.stderr)

    # Sanity check: report true L2 norms (not squared norms) of the first rows.
    sample_norms = torch.linalg.vector_norm(all_embs[:5], ord=2, dim=1)
    print("Sample norms (should be ~1.0):", sample_norms.tolist())
    return 0
|
|
# Script entry point: run the pipeline and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|