import argparse
from pathlib import Path
from typing import List, Dict

from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model

BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"


def find_training_pairs(data_dir: Path) -> List[Dict[str, str]]:
    """Walk the subfolders of data_dir looking for target_une_ad.srt / free_ad.txt pairs.

    Each example is formatted as an instruction-style prompt, using the SRT as input
    and the free narration as output.
    """
    examples: List[Dict[str, str]] = []
    if not data_dir.exists():
        raise FileNotFoundError(f"Data dir not found: {data_dir}")
    for item in sorted(data_dir.iterdir()):
        if not item.is_dir():
            continue
        srt_path = item / "target_une_ad.srt"
        free_path = item / "free_ad.txt"
        if not srt_path.exists() or not free_path.exists():
            continue
        srt_text = srt_path.read_text(encoding="utf-8")
        free_text = free_path.read_text(encoding="utf-8")
        # Instruction-tuning style prompt, kept in Catalan to stay consistent with the task
        prompt = (
            "Converteix el següent fitxer SRT d'audiodescripció UNE (amb restriccions temporals) "
            "en una narració lliure detallada en català, sense límits de temps. "
            "Mantén tota la informació visual rellevant però amb un to fluid i natural.\n\n"
            "### SRT UNE\n" + srt_text.strip() + "\n\n### Narració lliure:"
        )
        examples.append({"prompt": prompt, "output": free_text.strip()})
    if not examples:
        raise RuntimeError(
            f"No training pairs found in {data_dir} (expected target_une_ad.srt + free_ad.txt)"
        )
    return examples


def build_dataset(pairs: List[Dict[str, str]], tokenizer: AutoTokenizer, max_length: int = 2048) -> Dataset:
    """Build a Hugging Face Dataset from the prompt/output pairs.

    Each pair is concatenated into a single sequence for causal training:
    [PROMPT] + [OUTPUT] + eos, and the prompt tokens are masked so the loss
    is only computed on the output.
""" def _gen(): for ex in pairs: yield {"prompt": ex["prompt"], "output": ex["output"]} raw_ds = Dataset.from_generator(_gen) def tokenize_fn(batch): prompts = batch["prompt"] outputs = batch["output"] input_ids_list = [] labels_list = [] for p, o in zip(prompts, outputs): full_text = p + "\n" + o + tokenizer.eos_token enc = tokenizer( full_text, truncation=True, max_length=max_length, padding="max_length", ) # Máscara: ignorar loss en tokens del prompt prompt_ids = tokenizer(p + "\n", truncation=True, max_length=max_length)["input_ids"] prompt_len = min(len(prompt_ids), max_length) labels = enc["input_ids"].copy() for i in range(prompt_len): labels[i] = -100 input_ids_list.append(enc["input_ids"]) labels_list.append(labels) return {"input_ids": input_ids_list, "attention_mask": [([1] * max_length)] * len(input_ids_list), "labels": labels_list} tokenized = raw_ds.map(tokenize_fn, batched=True, remove_columns=["prompt", "output"]) return tokenized def create_lora_model(base_model_name: str, r: int = 16, alpha: int = 32, dropout: float = 0.05): model = AutoModelForCausalLM.from_pretrained( base_model_name, torch_dtype="auto", device_map="auto", ) lora_config = LoraConfig( r=r, lora_alpha=alpha, lora_dropout=dropout, bias="none", task_type="CAUSAL_LM", ) model = get_peft_model(model, lora_config) return model def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Fine-tuning LoRA per a salamandra-instruct-7b amb dades UNE/free AD") parser.add_argument( "--base_model", type=str, default="projecte-aina/salamandra-instruct-7b", help="Nom o ruta del model base (HF hub o path local)", ) parser.add_argument( "--data_dir", type=str, default=str(DATA_DIR), help="Directori base amb subcarpetes que contenen target_une_ad.srt i free_ad.txt", ) parser.add_argument( "--output_dir", type=str, default=str(BASE_DIR / "lora_output"), help="Directori on desar l'adapter LoRA", ) parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--gradient_accumulation", type=int, default=8) parser.add_argument("--epochs", type=int, default=3) parser.add_argument("--lr", type=float, default=2e-4) parser.add_argument("--max_length", type=int, default=2048) parser.add_argument("--warmup_ratio", type=float, default=0.03) parser.add_argument("--logging_steps", type=int, default=10) parser.add_argument("--save_steps", type=int, default=200) parser.add_argument("--eval_steps", type=int, default=200) parser.add_argument("--r", type=int, default=16, help="Rank de LoRA") parser.add_argument("--alpha", type=int, default=32, help="Alpha de LoRA") parser.add_argument("--dropout", type=float, default=0.05, help="Dropout de LoRA") return parser.parse_args() def main(): args = parse_args() data_dir = Path(args.data_dir) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) print(f"[lora] Buscant dades a: {data_dir}") pairs = find_training_pairs(data_dir) print(f"[lora] Nombre d'exemples trobats: {len(pairs)}") print(f"[lora] Carregant tokenizer de {args.base_model}") tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print("[lora] Construint dataset tokenitzat...") dataset = build_dataset(pairs, tokenizer, max_length=args.max_length) print(f"[lora] Carregant model base {args.base_model} i aplicant LoRA...") model = create_lora_model(args.base_model, r=args.r, alpha=args.alpha, dropout=args.dropout) training_args = TrainingArguments( output_dir=str(output_dir), 
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        # Evaluation is disabled: no eval split is built from the training pairs,
        # and Trainer would fail with an eval strategy but no eval_dataset.
        save_total_limit=2,
        bf16=True,
        gradient_checkpointing=True,
        report_to=[],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=None,
        tokenizer=tokenizer,
    )

    print("[lora] Starting training...")
    trainer.train()

    print("[lora] Saving LoRA adapter...")
    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
    print(f"[lora] Training complete. Adapter saved to {output_dir}")


if __name__ == "__main__":
    main()
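

# Minimal inference sketch (not part of the training script above): it shows how the
# saved adapter could be loaded with PEFT for generation. Assumptions: the adapter
# was written to the default ./lora_output directory, and `prompt` is an SRT-based
# prompt formatted exactly as in find_training_pairs; generation settings are
# illustrative only.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
#
#   base = AutoModelForCausalLM.from_pretrained(
#       "projecte-aina/salamandra-instruct-7b", torch_dtype="auto", device_map="auto"
#   )
#   model = PeftModel.from_pretrained(base, "lora_output")   # attach the LoRA adapter
#   tokenizer = AutoTokenizer.from_pretrained("lora_output")
#
#   inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#   output_ids = model.generate(**inputs, max_new_tokens=512)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))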