| import os
|
| import argparse
|
| from pathlib import Path
|
| from typing import List, Dict
|
|
|
| from datasets import Dataset
|
| from transformers import (
|
| AutoTokenizer,
|
| AutoModelForCausalLM,
|
| TrainingArguments,
|
| Trainer,
|
| )
|
| from peft import LoraConfig, get_peft_model
|
|
|
|
|
# Directory containing this script; used to resolve default paths below.
BASE_DIR = Path(__file__).resolve().parent

# Default root of the training data (subfolders holding SRT/TXT pairs).
DATA_DIR = BASE_DIR / "data"
|
|
|
|
|
def find_training_pairs(data_dir: Path) -> List[Dict[str, str]]:
    """Walk the subfolders of *data_dir* collecting target_une_ad.srt / free_ad.txt pairs.

    Each example is formatted as an instruct-style prompt, using the SRT file as
    the input and the free narration as the expected output.

    Args:
        data_dir: Root directory whose immediate subdirectories hold the pairs.

    Returns:
        A list of ``{"prompt": ..., "output": ...}`` dicts.

    Raises:
        FileNotFoundError: If *data_dir* does not exist.
        RuntimeError: If no complete pair is found under *data_dir*.
    """
    examples: List[Dict[str, str]] = []

    if not data_dir.exists():
        raise FileNotFoundError(f"Data dir not found: {data_dir}")

    for item in sorted(data_dir.iterdir()):
        if not item.is_dir():
            continue

        srt_path = item / "target_une_ad.srt"
        free_path = item / "free_ad.txt"

        # Incomplete pairs are skipped silently; only complete pairs train.
        if not srt_path.exists() or not free_path.exists():
            continue

        srt_text = srt_path.read_text(encoding="utf-8")
        free_text = free_path.read_text(encoding="utf-8")

        # BUGFIX: the original prompt literals were mojibake (UTF-8 Catalan that
        # had been decoded through a CJK codepage, e.g. "seg眉ent" for
        # "següent"); restored to proper Catalan here.
        prompt = (
            "Converteix el següent fitxer SRT d'audiodescripció UNE (amb restriccions temporals) "
            "en una narració lliure detallada en català, sense límits de temps. "
            "Mantén tota la informació visual rellevant però amb un to fluid i natural.\n\n"
            "### SRT UNE\n" + srt_text.strip() + "\n\n### Narració lliure:"
        )

        examples.append({"prompt": prompt, "output": free_text.strip()})

    if not examples:
        raise RuntimeError(f"No training pairs found in {data_dir} (expected target_une_ad.srt + free_ad.txt)")

    return examples
|
|
|
|
|
def build_dataset(pairs: List[Dict[str, str]], tokenizer: AutoTokenizer, max_length: int = 2048) -> Dataset:
    """Build a tokenized Hugging Face Dataset from the prompt/output pairs.

    Each example is concatenated into a single causal-LM sequence:
        [PROMPT] + "\\n" + [OUTPUT] + eos
    Prompt tokens and padding positions are masked with -100 so the loss is
    computed only on the output tokens.

    Args:
        pairs: ``{"prompt": ..., "output": ...}`` dicts from find_training_pairs.
        tokenizer: Tokenizer of the base model (pad_token must be set).
        max_length: Sequences are truncated/padded to this length.

    Returns:
        A Dataset with ``input_ids``, ``attention_mask`` and ``labels`` columns.
    """

    def _gen():
        for ex in pairs:
            yield {"prompt": ex["prompt"], "output": ex["output"]}

    raw_ds = Dataset.from_generator(_gen)

    def tokenize_fn(batch):
        prompts = batch["prompt"]
        outputs = batch["output"]

        input_ids_list = []
        attention_mask_list = []
        labels_list = []

        for p, o in zip(prompts, outputs):
            full_text = p + "\n" + o + tokenizer.eos_token
            enc = tokenizer(
                full_text,
                truncation=True,
                max_length=max_length,
                padding="max_length",
            )

            # Length of the prompt prefix (including the separating newline):
            # that many leading tokens are excluded from the loss.
            prompt_ids = tokenizer(p + "\n", truncation=True, max_length=max_length)["input_ids"]
            prompt_len = min(len(prompt_ids), max_length)

            labels = enc["input_ids"].copy()
            for i in range(prompt_len):
                labels[i] = -100
            # BUGFIX: also mask padding positions. The original kept pad-token
            # ids in the labels (loss on padding) and returned a hard-coded
            # all-ones attention mask, letting the model attend to padding.
            for i, m in enumerate(enc["attention_mask"]):
                if m == 0:
                    labels[i] = -100

            input_ids_list.append(enc["input_ids"])
            attention_mask_list.append(enc["attention_mask"])
            labels_list.append(labels)

        return {
            "input_ids": input_ids_list,
            "attention_mask": attention_mask_list,
            "labels": labels_list,
        }

    tokenized = raw_ds.map(tokenize_fn, batched=True, remove_columns=["prompt", "output"])
    return tokenized
|
|
|
|
|
def create_lora_model(base_model_name: str, r: int = 16, alpha: int = 32, dropout: float = 0.05):
    """Load the base causal LM and wrap it with a LoRA adapter.

    Args:
        base_model_name: HF hub id or local path of the base model.
        r: LoRA rank.
        alpha: LoRA scaling factor (lora_alpha).
        dropout: LoRA dropout probability.

    Returns:
        The PEFT-wrapped model ready for training.
    """
    base = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype="auto",
        device_map="auto",
    )

    adapter_cfg = LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )

    return get_peft_model(base, adapter_cfg)
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the LoRA fine-tuning run."""
    p = argparse.ArgumentParser(
        description="Fine-tuning LoRA per a salamandra-instruct-7b amb dades UNE/free AD"
    )

    # Paths and model selection.
    p.add_argument(
        "--base_model",
        type=str,
        default="projecte-aina/salamandra-instruct-7b",
        help="Nom o ruta del model base (HF hub o path local)",
    )
    p.add_argument(
        "--data_dir",
        type=str,
        default=str(DATA_DIR),
        help="Directori base amb subcarpetes que contenen target_une_ad.srt i free_ad.txt",
    )
    p.add_argument(
        "--output_dir",
        type=str,
        default=str(BASE_DIR / "lora_output"),
        help="Directori on desar l'adapter LoRA",
    )

    # Optimization hyperparameters.
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--gradient_accumulation", type=int, default=8)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--max_length", type=int, default=2048)
    p.add_argument("--warmup_ratio", type=float, default=0.03)
    p.add_argument("--logging_steps", type=int, default=10)
    p.add_argument("--save_steps", type=int, default=200)
    p.add_argument("--eval_steps", type=int, default=200)

    # LoRA hyperparameters.
    p.add_argument("--r", type=int, default=16, help="Rank de LoRA")
    p.add_argument("--alpha", type=int, default=32, help="Alpha de LoRA")
    p.add_argument("--dropout", type=float, default=0.05, help="Dropout de LoRA")

    return p.parse_args()
|
|
|
|
|
def main():
    """Entry point: discover data, tokenize, apply LoRA and run training."""
    args = parse_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"[lora] Buscant dades a: {data_dir}")
    pairs = find_training_pairs(data_dir)
    print(f"[lora] Nombre d'exemples trobats: {len(pairs)}")

    print(f"[lora] Carregant tokenizer de {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    if tokenizer.pad_token is None:
        # Causal LMs often ship without a pad token; reuse eos for padding.
        tokenizer.pad_token = tokenizer.eos_token

    print("[lora] Construint dataset tokenitzat...")
    dataset = build_dataset(pairs, tokenizer, max_length=args.max_length)

    print(f"[lora] Carregant model base {args.base_model} i aplicant LoRA...")
    model = create_lora_model(args.base_model, r=args.r, alpha=args.alpha, dropout=args.dropout)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        # BUGFIX: the original set evaluation_strategy="steps" while passing
        # eval_dataset=None to the Trainer, which raises at the first eval
        # step ("evaluation requires an eval_dataset"). No eval split exists,
        # so evaluation is disabled; --eval_steps is kept for compatibility.
        evaluation_strategy="no",
        save_total_limit=2,
        bf16=True,  # NOTE(review): assumes bf16-capable hardware — confirm
        gradient_checkpointing=True,
        report_to=[],  # disable wandb/tensorboard reporting
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=None,
        tokenizer=tokenizer,
    )

    print("[lora] Iniciant entrenament...")
    trainer.train()

    # Save only the LoRA adapter weights (PEFT save_pretrained) plus the
    # tokenizer, so the adapter can be reloaded on top of the base model.
    print("[lora] Guardant adapter LoRA...")
    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    print(f"[lora] Entrenament completat. Adapter guardat a {output_dir}")


if __name__ == "__main__":
    main()
|
|
|