Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import hashlib | |
| import json | |
| import os | |
| import random | |
| import re | |
| from pathlib import Path | |
| import ctranslate2 | |
def convert_base_model(model_name: str, output_dir: str, quantization: str) -> None:
    """Convert a Hugging Face Transformers checkpoint into a CTranslate2 model directory.

    The NLLB tokenizer assets (``tokenizer.json`` and ``sentencepiece.bpe.model``) are
    copied next to the converted weights. Any existing ``output_dir`` is overwritten.

    Args:
        model_name: Hub id or local path of the checkpoint to convert.
        output_dir: Destination directory for the CTranslate2 model.
        quantization: CTranslate2 quantization type (e.g. ``"int8"``).
    """
    print(f"Converting model {model_name} to {output_dir} with {quantization} quantization...")
    tokenizer_assets = ["tokenizer.json", "sentencepiece.bpe.model"]
    ct2_converter = ctranslate2.converters.TransformersConverter(model_name, copy_files=tokenizer_assets)
    # force=True lets repeated conversions replace a previous output directory.
    ct2_converter.convert(output_dir, force=True, quantization=quantization)
    saved_at = os.path.abspath(output_dir)
    print(f"Model saved to {saved_at}")
def prepare_data(out_dir: str, seed: int, val_ratio: float, test_ratio: float, max_per_corpus: int) -> None:
    """Build a deduplicated, bidirectional EN<->IT corpus and write train/val/test JSONL splits.

    Parallel pairs come from OpenSubtitles, OPUS Books, Europarl and Tatoeba. Each of the
    last three is capped to its first ``max_per_corpus`` examples; OpenSubtitles is capped
    to four times that. Pairs are whitespace-normalized, filtered by length (2..120
    whitespace tokens per side, at most a 3x length ratio) and deduplicated.

    The train/val/test split is made on *pairs*. Only afterwards is each pair expanded
    into both directions (``eng_Latn -> ita_Latn`` and ``ita_Latn -> eng_Latn``). This
    guarantees that the reverse direction of a held-out pair never appears in training.

    Args:
        out_dir: Directory that receives ``train.jsonl``, ``val.jsonl``, ``test.jsonl``
            and ``metadata.json``. It is created if missing.
        seed: Seed for the shuffles that decide split membership and row order.
        val_ratio: Fraction of unique pairs reserved for validation.
        test_ratio: Fraction of unique pairs reserved for test.
        max_per_corpus: Per-corpus cap on the number of examples read.

    Raises:
        ValueError: If a ratio is negative or the two ratios leave no training data.
    """
    import warnings
    from datasets import load_dataset
    from datasets import concatenate_datasets

    # Fail fast, before any download, on ratios that would leave an empty training split.
    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio >= 1:
        raise ValueError(
            f"val_ratio ({val_ratio}) and test_ratio ({test_ratio}) must be non-negative and sum to less than 1"
        )

    warnings.filterwarnings(
        "ignore",
        message=r".*Helsinki-NLP/tatoeba_mt contains custom code.*",
        category=FutureWarning,
    )

    def normalize(text: str) -> str:
        """Trim the text and collapse every whitespace run to a single space."""
        return re.sub(r"\s+", " ", text.strip())

    def extract_value(item: dict, key: str):
        """Follow a dotted path (e.g. ``translation.en``) into a nested example dict."""
        value = item
        for part in key.split("."):
            value = value[part]
        return value

    def to_pairs(ds, en_key: str, it_key: str, name: str) -> list[tuple[str, str, str]]:
        """Return normalized ``(english, italian, dataset_name)`` triples, skipping empty sides."""
        pairs = []
        for item in ds:
            en_raw = extract_value(item, en_key)
            it_raw = extract_value(item, it_key)
            # A missing (non-string) side is treated like empty text and the pair is skipped.
            en_text = normalize(en_raw) if isinstance(en_raw, str) else ""
            it_text = normalize(it_raw) if isinstance(it_raw, str) else ""
            if en_text and it_text:
                pairs.append((en_text, it_text, name))
        return pairs

    def load_tatoeba_rows():
        """
        Handle split changes in newer datasets releases where tatoeba_mt may expose
        validation/test only instead of train.
        """
        try:
            return load_dataset("Helsinki-NLP/tatoeba_mt", "eng-ita", split="train")
        except ValueError:
            validation = load_dataset("Helsinki-NLP/tatoeba_mt", "eng-ita", split="validation")
            test = load_dataset("Helsinki-NLP/tatoeba_mt", "eng-ita", split="test")
            return concatenate_datasets([validation, test])

    def keep_pair(en_text: str, it_text: str) -> bool:
        """Keep pairs with 2..120 whitespace tokens per side and at most a 3x length ratio."""
        len_en = len(en_text.split())
        len_it = len(it_text.split())
        if not (2 <= len_en <= 120 and 2 <= len_it <= 120):
            return False
        return max(len_en, len_it) / max(1, min(len_en, len_it)) <= 3.0

    def expand(split_pairs: list[tuple[str, str, str]], rng: random.Random) -> list[dict]:
        """Emit both translation directions for every pair, then shuffle the rows."""
        rows = []
        for en_text, it_text, name in split_pairs:
            rows.append({"source_text": en_text, "target_text": it_text, "source_lang": "eng_Latn", "target_lang": "ita_Latn", "dataset": name})
            rows.append({"source_text": it_text, "target_text": en_text, "source_lang": "ita_Latn", "target_lang": "eng_Latn", "dataset": name})
        rng.shuffle(rows)
        return rows

    books = load_dataset("opus_books", "en-it", split="train")
    europarl = load_dataset("Helsinki-NLP/europarl", "en-it", split="train")
    tatoeba = load_tatoeba_rows()
    subs = load_dataset("open_subtitles", lang1="en", lang2="it", trust_remote_code=True, split="train")
    books = books.select(range(min(len(books), max_per_corpus)))
    europarl = europarl.select(range(min(len(europarl), max_per_corpus)))
    tatoeba = tatoeba.select(range(min(len(tatoeba), max_per_corpus)))
    subs = subs.select(range(min(len(subs), max_per_corpus * 4)))

    # Corpus order matters: on duplicates, the first corpus listed keeps the pair.
    candidates = []
    candidates.extend(to_pairs(subs, "translation.en", "translation.it", "open_subtitles"))
    candidates.extend(to_pairs(books, "translation.en", "translation.it", "opus_books"))
    candidates.extend(to_pairs(europarl, "translation.en", "translation.it", "europarl"))
    # Assumes the eng-ita Tatoeba config stores English in sourceString (as the original mapping did).
    candidates.extend(to_pairs(tatoeba, "sourceString", "targetString", "tatoeba"))

    pairs = []
    seen = set()
    for en_text, it_text, name in candidates:
        if not keep_pair(en_text, it_text):
            continue
        # Tuple keys avoid the collisions a delimiter-joined string key can produce.
        key = (en_text, it_text)
        if key in seen:
            continue
        seen.add(key)
        pairs.append((en_text, it_text, name))

    # Split unique pairs first so both directions of a pair always share the same split.
    rng = random.Random(seed)
    rng.shuffle(pairs)
    n_pairs = len(pairs)
    n_test = int(n_pairs * test_ratio)
    n_val = int(n_pairs * val_ratio)
    test_rows = expand(pairs[:n_test], rng)
    val_rows = expand(pairs[n_test : n_test + n_val], rng)
    train_rows = expand(pairs[n_test + n_val :], rng)
    n_total = len(train_rows) + len(val_rows) + len(test_rows)

    path = Path(out_dir)
    path.mkdir(parents=True, exist_ok=True)

    def write_jsonl(name: str, samples: list[dict]) -> None:
        """Write one JSON object per line, keeping non-ASCII characters readable."""
        with (path / name).open("w", encoding="utf-8") as f:
            for row in samples:
                f.write(json.dumps(row, ensure_ascii=False) + "\n")

    write_jsonl("train.jsonl", train_rows)
    write_jsonl("val.jsonl", val_rows)
    write_jsonl("test.jsonl", test_rows)
    (path / "metadata.json").write_text(
        json.dumps(
            {
                "datasets": ["opus_books", "europarl", "tatoeba", "open_subtitles"],
                "seed": seed,
                "counts": {"train": len(train_rows), "val": len(val_rows), "test": len(test_rows), "total": n_total},
                "unique_pairs": n_pairs,
            },
            indent=2,
        ),
        encoding="utf-8",
    )
    print(f"Wrote curated dataset to {path}")
def train_lora(
    data_dir: str,
    output_dir: str,
    model_name: str,
    train_batch_size: int,
    eval_batch_size: int,
    gradient_accumulation_steps: int,
    max_length: int,
    num_train_epochs: float,
    resume_from_checkpoint: str | None,
    save_steps: int,
    eval_steps: int,
    logging_steps: int,
    bf16: bool,
    dataloader_num_workers: int,
    max_train_samples: int | None,
    max_eval_samples: int | None,
    eval_during_train: bool,
    final_eval: bool,
) -> None:
    """Fine-tune a LoRA adapter on the ``train.jsonl`` / ``val.jsonl`` splits in ``data_dir``.

    The model ``model_name`` is wrapped with LoRA (r=16, alpha=32, dropout=0.05 on the
    q/k/v/o projections) and trained with a ``Seq2SeqTrainer`` at learning rate 2e-4.
    Outputs:
    - checkpoints in ``<output_dir>/checkpoints``;
    - the adapter and tokenizer in ``<output_dir>/adapter``;
    - evaluation metrics in ``<output_dir>/final_metrics.json`` (an empty object when
      ``final_eval`` is False).

    Args:
        data_dir: Directory holding ``train.jsonl`` and ``val.jsonl`` with the columns
            ``source_text``, ``target_text``, ``source_lang`` and ``target_lang``.
        output_dir: Root directory for checkpoints, adapter and metrics.
        model_name: Hub id or path of the base seq2seq model and its tokenizer.
        train_batch_size: Per-device training batch size.
        eval_batch_size: Per-device evaluation batch size.
        gradient_accumulation_steps: Number of steps to accumulate before each optimizer update.
        max_length: Truncation length, in tokens, for both source and target.
        num_train_epochs: Number of training epochs (may be fractional).
        resume_from_checkpoint: Checkpoint passed to ``trainer.train``. When set, torch
            weights-only loading is disabled (use trusted local checkpoints only).
        save_steps: Checkpoint interval in steps.
        eval_steps: Evaluation interval in steps (used when ``eval_during_train``).
        logging_steps: Logging interval in steps.
        bf16: Enable bfloat16 mixed precision.
        dataloader_num_workers: Number of dataloader worker processes. Workers are
            persistent when this is > 0.
        max_train_samples: Optional cap that keeps only the first N training rows.
        max_eval_samples: Optional cap that keeps only the first N validation rows.
        eval_during_train: Evaluate periodically and reload the best checkpoint by BLEU at the end.
        final_eval: Run one more evaluation on the validation split after training.
    """
    import evaluate
    import numpy as np
    import torch
    import warnings
    from datasets import load_dataset
    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments

    class _Seq2SeqTrainer(Seq2SeqTrainer):
        """HF Seq2SeqTrainer still reads `self.tokenizer` in `_pad_tensors_to_max_len`, which logs deprecation; use `processing_class`."""

        def _pad_tensors_to_max_len(self, tensor, max_length):
            tok = self.processing_class
            if tok is not None and hasattr(tok, "pad_token_id"):
                pad_token_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
            else:
                if self.model.config.pad_token_id is not None:
                    pad_token_id = self.model.config.pad_token_id
                else:
                    raise ValueError(
                        "Pad_token_id must be set in the configuration of the model, in order to pad tensors"
                    )
            padded_tensor = pad_token_id * torch.ones(
                (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
            )
            padded_tensor[:, : tensor.shape[-1]] = tensor
            return padded_tensor

    if resume_from_checkpoint:
        # PyTorch 2.6+ defaults weights_only=True; HF checkpoint RNG state needs full unpickle (trusted local dirs only).
        # Force-disable weights_only behavior for resume to avoid RNG-state unpickling failures.
        os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"
        os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
        os.environ["HF_TRAINER_USE_WEIGHTS_ONLY_LOAD"] = "false"
        warnings.filterwarnings("ignore", category=FutureWarning, module=r"transformers\.trainer")

    dataset = load_dataset(
        "json",
        data_files={"train": str(Path(data_dir) / "train.jsonl"), "validation": str(Path(data_dir) / "val.jsonl")},
    )
    if max_train_samples is not None:
        dataset["train"] = dataset["train"].select(range(min(max_train_samples, len(dataset["train"]))))
    if max_eval_samples is not None:
        dataset["validation"] = dataset["validation"].select(range(min(max_eval_samples, len(dataset["validation"]))))

    processing_class = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    lora_cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    model = get_peft_model(model, lora_cfg)

    def preprocess(batch: dict) -> dict:
        """Tokenize each row with its own language codes, since rows mix both directions."""
        tokenized = {"input_ids": [], "attention_mask": [], "labels": []}
        for src_text, tgt_text, src_lang, tgt_lang in zip(
            batch["source_text"], batch["target_text"], batch["source_lang"], batch["target_lang"]
        ):
            # Languages are set per example because a single batch contains both directions.
            processing_class.src_lang = src_lang
            processing_class.tgt_lang = tgt_lang
            inputs = processing_class(src_text, max_length=max_length, truncation=True)
            labels = processing_class(text_target=tgt_text, max_length=max_length, truncation=True)
            tokenized["input_ids"].append(inputs["input_ids"])
            tokenized["attention_mask"].append(inputs["attention_mask"])
            tokenized["labels"].append(labels["input_ids"])
        return tokenized

    tokenized = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)
    collator = DataCollatorForSeq2Seq(tokenizer=processing_class, model=model)
    bleu_metric = evaluate.load("sacrebleu")
    chrf_metric = evaluate.load("chrf")

    def compute_metrics(eval_preds: tuple[np.ndarray, np.ndarray]) -> dict[str, float]:
        """Decode generated ids and labels, then score them with sacreBLEU and chrF."""
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        # The evaluation loop pads generations from different batches with -100 when it
        # concatenates them; those ids cannot be decoded, so map them to the pad token first.
        preds = np.where(preds != -100, preds, processing_class.pad_token_id)
        decoded_preds = processing_class.batch_decode(preds, skip_special_tokens=True)
        # Ignored label positions (-100) are mapped to pad for the same reason.
        labels = np.where(labels != -100, labels, processing_class.pad_token_id)
        decoded_labels = processing_class.batch_decode(labels, skip_special_tokens=True)
        references = [[x] for x in decoded_labels]
        bleu = bleu_metric.compute(predictions=decoded_preds, references=references)["score"]
        chrf = chrf_metric.compute(predictions=decoded_preds, references=references)["score"]
        return {"bleu": round(bleu, 2), "chrf": round(chrf, 2)}

    needs_eval = eval_during_train or final_eval
    train_args = Seq2SeqTrainingArguments(
        output_dir=str(Path(output_dir) / "checkpoints"),
        learning_rate=2e-4,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=num_train_epochs,
        eval_strategy="steps" if eval_during_train else "no",
        eval_steps=eval_steps,
        save_steps=save_steps,
        logging_strategy="steps",
        logging_steps=logging_steps,
        bf16=bf16,
        dataloader_num_workers=dataloader_num_workers,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=dataloader_num_workers > 0,
        predict_with_generate=needs_eval,
        report_to="none",
        metric_for_best_model="bleu" if eval_during_train else None,
        greater_is_better=True,
        load_best_model_at_end=eval_during_train,
    )
    trainer = _Seq2SeqTrainer(
        model=model,
        args=train_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"] if needs_eval else None,
        processing_class=processing_class,
        data_collator=collator,
        compute_metrics=compute_metrics if needs_eval else None,
    )
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    metrics = trainer.evaluate() if final_eval else {}

    adapter_dir = Path(output_dir) / "adapter"
    adapter_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(adapter_dir)
    processing_class.save_pretrained(adapter_dir)
    (Path(output_dir) / "final_metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
    print(f"Saved LoRA adapter to {adapter_dir}")
def export_lora(base_model: str, adapter_dir: str, output_dir: str, quantization: str) -> None:
    """Merge a LoRA adapter into its base model and convert the result to CTranslate2.

    Outputs:
    - the merged Hugging Face checkpoint, plus the base model's tokenizer, in
      ``<output_dir>/merged_hf``;
    - the CTranslate2 model in ``<output_dir>/model``.

    Re-running overwrites a previous export.

    Args:
        base_model: Hub id or path of the model the adapter was trained on.
        adapter_dir: Directory containing the saved PEFT adapter.
        output_dir: Root directory for the merged and converted artifacts.
        quantization: CTranslate2 quantization type (e.g. ``"int8"``).
    """
    from peft import PeftModel
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    target = Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model)
    merged = PeftModel.from_pretrained(base, adapter_dir).merge_and_unload()
    merged_hf = target / "merged_hf"
    merged.save_pretrained(merged_hf)
    AutoTokenizer.from_pretrained(base_model).save_pretrained(merged_hf)

    class _CompatTransformersConverter(ctranslate2.converters.TransformersConverter):
        """Bridge ctranslate2/transformers dtype kwarg compatibility across versions."""

        def load_model(self, model_class, model_name_or_path, **kwargs):
            # Some ctranslate2 versions pass `dtype`, while transformers expects `torch_dtype`.
            if "dtype" in kwargs and "torch_dtype" not in kwargs:
                kwargs["torch_dtype"] = kwargs.pop("dtype")
            try:
                return model_class.from_pretrained(model_name_or_path, **kwargs)
            except TypeError:
                # Fallback for older/newer transformers combinations that reject dtype args.
                if "torch_dtype" not in kwargs:
                    raise
                kwargs.pop("torch_dtype")
                return model_class.from_pretrained(model_name_or_path, **kwargs)

    # Ship the same tokenizer assets that the base conversion copies, but only those the
    # saved tokenizer actually produced (the converter rejects missing files).
    tokenizer_files = [
        name for name in ("tokenizer.json", "sentencepiece.bpe.model") if (merged_hf / name).is_file()
    ]
    converter = _CompatTransformersConverter(str(merged_hf), copy_files=tokenizer_files or None)
    model_dir = target / "model"
    # force=True lets a repeated export replace an existing model directory instead of failing.
    converter.convert(str(model_dir), quantization=quantization, force=True)
    print(f"Exported CTranslate2 model to {model_dir}")
def main() -> None:
    """Parse the command line and run the selected workflow.

    Sub-commands: ``convert-base``, ``prepare-data``, ``train-lora`` and ``export-lora``.
    """
    default_model = "facebook/nllb-200-distilled-600M"
    parser = argparse.ArgumentParser(description="NMT utility script for conversion/data/LoRA workflows.")
    subcommands = parser.add_subparsers(dest="command", required=True)

    convert_cmd = subcommands.add_parser("convert-base")
    convert_cmd.add_argument("--model-name", default=default_model)
    convert_cmd.add_argument("--output-dir", default="nllb_int8")
    convert_cmd.add_argument("--quantization", default="int8")

    data_cmd = subcommands.add_parser("prepare-data")
    data_cmd.add_argument("--out-dir", default="data/en_it_v1")
    for flag, kind, default in (
        ("--seed", int, 42),
        ("--val-ratio", float, 0.05),
        ("--test-ratio", float, 0.05),
        ("--max-per-corpus", int, 120000),
    ):
        data_cmd.add_argument(flag, type=kind, default=default)

    train_cmd = subcommands.add_parser("train-lora")
    train_cmd.add_argument("--data-dir", default="data/en_it_v1")
    train_cmd.add_argument("--output-dir", default="artifacts/lora/en_it_v1")
    train_cmd.add_argument("--model-name", default=default_model)
    for flag, kind, default in (
        ("--train-batch-size", int, 8),
        ("--eval-batch-size", int, 8),
        ("--gradient-accumulation-steps", int, 1),
        ("--max-length", int, 192),
        ("--num-train-epochs", float, 2.0),
    ):
        train_cmd.add_argument(flag, type=kind, default=default)
    train_cmd.add_argument("--resume-from-checkpoint", default=None)
    for flag in ("--save-steps", "--eval-steps", "--logging-steps"):
        train_cmd.add_argument(flag, type=int, default=500)
    train_cmd.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    train_cmd.add_argument("--dataloader-num-workers", type=int, default=4)
    train_cmd.add_argument("--max-train-samples", type=int, default=None)
    train_cmd.add_argument("--max-eval-samples", type=int, default=None)
    for flag in ("--eval-during-train", "--final-eval"):
        train_cmd.add_argument(flag, action=argparse.BooleanOptionalAction, default=True)

    export_cmd = subcommands.add_parser("export-lora")
    export_cmd.add_argument("--base-model", default=default_model)
    export_cmd.add_argument("--adapter-dir", required=True)
    export_cmd.add_argument("--output-dir", default="artifacts/ct2/en_it_lora_int8")
    export_cmd.add_argument("--quantization", default="int8")

    args = parser.parse_args()

    dispatch = {
        "convert-base": lambda: convert_base_model(args.model_name, args.output_dir, args.quantization),
        "prepare-data": lambda: prepare_data(
            args.out_dir, args.seed, args.val_ratio, args.test_ratio, args.max_per_corpus
        ),
        "train-lora": lambda: train_lora(
            data_dir=args.data_dir,
            output_dir=args.output_dir,
            model_name=args.model_name,
            train_batch_size=args.train_batch_size,
            eval_batch_size=args.eval_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            max_length=args.max_length,
            num_train_epochs=args.num_train_epochs,
            resume_from_checkpoint=args.resume_from_checkpoint,
            save_steps=args.save_steps,
            eval_steps=args.eval_steps,
            logging_steps=args.logging_steps,
            bf16=args.bf16,
            dataloader_num_workers=args.dataloader_num_workers,
            max_train_samples=args.max_train_samples,
            max_eval_samples=args.max_eval_samples,
            eval_during_train=args.eval_during_train,
            final_eval=args.final_eval,
        ),
    }
    # The sub-command is required, so any command not in the table is "export-lora".
    run = dispatch.get(
        args.command,
        lambda: export_lora(args.base_model, args.adapter_dir, args.output_dir, args.quantization),
    )
    run()
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()