| | |
| | """ |
| | Produce answer embeddings for the dataset using a fine-tuned MuRIL model. |
| | Saves: |
| | - muril_multilingual_dataset.csv (columns: question, answer, language) |
| | - answer_embeddings.pt (torch tensor shape [N, D], float32, on CPU) |
| | |
| | Usage: |
| | python embed_build_muril.py \ |
| | --model_dir ./muril_multilang_out \ |
| | --input_jsonl /path/to/legal_multilingual_QA_10k.jsonl \ |
| | --out_dir ./export_artifacts \ |
| | --batch_size 64 |
| | """ |
| | import argparse, os, math |
| | from pathlib import Path |
| |
|
| | import torch |
| | import pandas as pd |
| | from tqdm.auto import tqdm |
| | from transformers import AutoTokenizer, AutoModel |
| |
|
def parse_args():
    """Parse command-line options for the embedding-export script.

    Returns an ``argparse.Namespace`` with model/data paths, the language
    list to stack, column prefixes, batch size, and the compute device
    (defaults to CUDA when available).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="./muril_multilang_out", help="Path or HF repo id of fine-tuned MuRIL")
    parser.add_argument("--input_jsonl", type=str, required=True, help="Path to legal_multilingual_QA_10k.jsonl")
    parser.add_argument("--out_dir", type=str, default="./export_artifacts")
    parser.add_argument("--langs", type=str, default="en,hi,mr,ta,bn,gu,kn,ml,pa,or,as,ur,sa,ne", help="comma-separated languages to merge (will stack)")
    parser.add_argument("--text_prefix", type=str, default="question_", help="prefix for question columns in JSONL")
    parser.add_argument("--answer_col_prefix", type=str, default="answer_", help="prefix for answer columns if present (not used here)")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    return parser.parse_args()
| |
|
def mean_pooling(last_hidden_state, attention_mask):
    """Mask-aware mean of token embeddings over the sequence dimension.

    Broadcasts ``attention_mask`` over the hidden dimension so padding
    tokens contribute nothing, then divides the masked sum by the token
    count (clamped to 1e-9 to avoid division by zero on all-pad rows).
    Returns a tensor of shape ``(batch, hidden_dim)``.
    """
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = (last_hidden_state * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / token_counts
| |
|
def build_question_answer_rows(df, langs, text_prefix, answer_prefix="answer_"):
    """Stack per-language question/answer columns into long-format rows.

    For each input row and each language code in ``langs``, reads the
    question from ``f"{text_prefix}{lang}"`` and the answer from
    ``f"{answer_prefix}{lang}"``, falling back to a shared ``"answer"``
    column when the per-language answer column is absent or NaN.
    Rows with an empty/NaN question or answer are skipped.

    Args:
        df: Input DataFrame (one row per source record).
        langs: Language codes to extract (e.g. ``["en", "hi"]``).
        text_prefix: Prefix of the question columns.
        answer_prefix: Prefix of the per-language answer columns.
            Defaults to ``"answer_"``; previously this was hard-coded even
            though the CLI exposes ``--answer_col_prefix``.

    Returns:
        DataFrame with columns ``question``, ``answer``, ``language``.
    """
    rows = []
    for _, r in df.iterrows():
        for lang in langs:
            qcol = f"{text_prefix}{lang}"
            acol = f"{answer_prefix}{lang}"

            q = r.get(qcol, None)
            # NaN survives a str() round-trip as "nan", so check all three forms.
            if q is None or str(q).strip() == "" or str(q).lower() == "nan":
                continue

            # Prefer the per-language answer; fall back to the shared column.
            if acol in df.columns and pd.notna(r.get(acol)):
                a = r.get(acol)
            else:
                a = r.get("answer", None)
            if a is None or str(a).strip() == "" or str(a).lower() == "nan":
                continue
            rows.append({"question": str(q).strip(), "answer": str(a).strip(), "language": lang})
    return pd.DataFrame(rows)
| |
|
def main():
    """Entry point: merge per-language Q/A columns, embed answers, save artifacts.

    Pipeline: read the JSONL dataset, stack language columns into long
    format, write the merged CSV, then encode every answer with the
    fine-tuned MuRIL model (masked mean pooling + L2 normalization) and
    save the resulting float32 CPU tensor.
    """
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    print("Loading dataset:", args.input_jsonl)
    # dtype=str keeps IDs/codes from being coerced into numerics.
    raw_df = pd.read_json(args.input_jsonl, lines=True, dtype=str)

    languages = [code.strip() for code in args.langs.split(",") if code.strip()]
    print("Merging language columns (stack)... langs:", languages)
    merged_df = build_question_answer_rows(raw_df, languages, args.text_prefix)
    if merged_df.empty:
        raise SystemExit("No question/answer rows found after merging languages. Check your columns.")
    print(f"Total rows extracted: {len(merged_df)}")

    csv_path = Path(args.out_dir) / "muril_multilingual_dataset.csv"
    merged_df.to_csv(csv_path, index=False, encoding="utf-8")
    print("Saved merged CSV to:", csv_path)

    print("Loading tokenizer & model from:", args.model_dir, "device:", args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    model = AutoModel.from_pretrained(args.model_dir)
    model.to(args.device)
    model.eval()

    texts = merged_df["answer"].astype(str).tolist()
    step = int(args.batch_size)
    chunks = []
    with torch.inference_mode():
        for start in tqdm(range(0, len(texts), step), desc="Encoding"):
            batch = texts[start:start + step]
            enc = tokenizer(batch, padding=True, truncation=True, max_length=256, return_tensors="pt")
            ids = enc["input_ids"].to(args.device)
            mask = enc["attention_mask"].to(args.device)
            outputs = model(input_ids=ids, attention_mask=mask, return_dict=True)
            pooled = mean_pooling(outputs.last_hidden_state, mask)
            # L2-normalize so downstream dot products are cosine similarities.
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            chunks.append(pooled.cpu())
    embeddings = torch.cat(chunks, dim=0)
    print("Embeddings shape:", embeddings.shape)
    embed_path = Path(args.out_dir) / "answer_embeddings.pt"
    torch.save(embeddings, embed_path)
    print("Saved embeddings to:", embed_path)

    print("Done. Artifacts in:", args.out_dir)
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|