|
|
import os
|
|
|
import argparse
|
|
|
from pathlib import Path
|
|
|
from typing import List, Dict
|
|
|
|
|
|
from datasets import Dataset
|
|
|
from transformers import (
|
|
|
AutoTokenizer,
|
|
|
AutoModelForCausalLM,
|
|
|
TrainingArguments,
|
|
|
Trainer,
|
|
|
)
|
|
|
from peft import LoraConfig, get_peft_model
|
|
|
|
|
|
|
|
|
# Directory containing this script; all relative paths resolve against it.
BASE_DIR = Path(__file__).resolve().parent
# Default training-data root: subfolders each holding an SRT/free-text pair.
DATA_DIR = BASE_DIR / "data"
|
|
|
|
|
|
|
|
|
def find_training_pairs(data_dir: Path) -> List[Dict[str, str]]:
    """Walk the subfolders of *data_dir* collecting target_une_ad.srt / free_ad.txt pairs.

    Each pair is formatted as an instruct-style example: the UNE SRT file is the
    input and the free-form narration is the expected output.

    Args:
        data_dir: Base directory whose immediate subdirectories each contain a
            ``target_une_ad.srt`` and a ``free_ad.txt`` file.

    Returns:
        A list of ``{"prompt": ..., "output": ...}`` dicts, one per valid pair.

    Raises:
        FileNotFoundError: If *data_dir* does not exist.
        RuntimeError: If no valid pair is found under *data_dir*.
    """
    examples: List[Dict[str, str]] = []

    if not data_dir.exists():
        raise FileNotFoundError(f"Data dir not found: {data_dir}")

    for item in sorted(data_dir.iterdir()):
        if not item.is_dir():
            continue

        srt_path = item / "target_une_ad.srt"
        free_path = item / "free_ad.txt"

        # Subfolders missing either half of the pair are silently skipped.
        if not srt_path.exists() or not free_path.exists():
            continue

        srt_text = srt_path.read_text(encoding="utf-8")
        free_text = free_path.read_text(encoding="utf-8")

        # Bugfix: the original literal was UTF-8 text mis-decoded as GBK
        # (mojibake, e.g. "seg眉ent" for "següent"); restored proper Catalan.
        prompt = (
            "Converteix el següent fitxer SRT d'audiodescripció UNE (amb restriccions temporals) "
            "en una narració lliure detallada en català, sense límits de temps. "
            "Mantén tota la informació visual rellevant però amb un to fluid i natural.\n\n"
            "### SRT UNE\n" + srt_text.strip() + "\n\n### Narració lliure:"
        )

        examples.append({"prompt": prompt, "output": free_text.strip()})

    if not examples:
        raise RuntimeError(f"No training pairs found in {data_dir} (expected target_une_ad.srt + free_ad.txt)")

    return examples
|
|
|
|
|
|
|
|
|
def build_dataset(pairs: List[Dict[str, str]], tokenizer: AutoTokenizer, max_length: int = 2048) -> Dataset:
    """Build a Hugging Face Dataset from prompt/output pairs for causal-LM training.

    Each example is concatenated into one sequence for causal training:
        [PROMPT] + "\\n" + [OUTPUT] + eos
    Prompt tokens and padding tokens are masked with -100 so the loss is
    computed only over the output tokens.

    Args:
        pairs: List of ``{"prompt": ..., "output": ...}`` examples.
        tokenizer: Tokenizer for the base model; must define ``eos_token``.
        max_length: Sequences are truncated/padded to this many tokens.

    Returns:
        A tokenized ``Dataset`` with ``input_ids``, ``attention_mask`` and
        ``labels`` columns.
    """

    def _gen():
        for ex in pairs:
            yield {"prompt": ex["prompt"], "output": ex["output"]}

    raw_ds = Dataset.from_generator(_gen)

    def tokenize_fn(batch):
        prompts = batch["prompt"]
        outputs = batch["output"]

        input_ids_list = []
        attention_list = []
        labels_list = []

        for p, o in zip(prompts, outputs):
            full_text = p + "\n" + o + tokenizer.eos_token
            enc = tokenizer(
                full_text,
                truncation=True,
                max_length=max_length,
                padding="max_length",
            )

            # Tokenize the prompt alone to know how many leading tokens to mask.
            # NOTE(review): assumes tokenizing the prompt prefix separately yields
            # the same leading tokens as tokenizing full_text — usually true at a
            # "\n" boundary for BPE tokenizers, but worth confirming per model.
            prompt_ids = tokenizer(p + "\n", truncation=True, max_length=max_length)["input_ids"]
            prompt_len = min(len(prompt_ids), max_length)

            attention_mask = enc["attention_mask"]
            labels = enc["input_ids"].copy()
            for i in range(len(labels)):
                # Bugfix: the original hard-coded attention_mask to all ones and
                # left padding-token ids in the labels, so the model attended to
                # padding and took loss on it (pad == eos here). Mask both the
                # prompt span and every padded position.
                if i < prompt_len or attention_mask[i] == 0:
                    labels[i] = -100

            input_ids_list.append(enc["input_ids"])
            attention_list.append(attention_mask)
            labels_list.append(labels)

        return {
            "input_ids": input_ids_list,
            "attention_mask": attention_list,
            "labels": labels_list,
        }

    tokenized = raw_ds.map(tokenize_fn, batched=True, remove_columns=["prompt", "output"])
    return tokenized
|
|
|
|
|
|
|
|
|
def create_lora_model(base_model_name: str, r: int = 16, alpha: int = 32, dropout: float = 0.05):
    """Load a causal-LM base model and wrap it with a LoRA adapter.

    Args:
        base_model_name: HF hub id or local path of the base model.
        r: LoRA rank.
        alpha: LoRA scaling factor.
        dropout: Dropout applied inside the LoRA layers.

    Returns:
        The PEFT-wrapped model ready for training.
    """
    base = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype="auto",
        device_map="auto",
    )

    adapter_cfg = LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )

    return get_peft_model(base, adapter_cfg)
|
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options for this training script."""
    ap = argparse.ArgumentParser(
        description="Fine-tuning LoRA per a salamandra-instruct-7b amb dades UNE/free AD"
    )

    # Paths and model selection.
    ap.add_argument("--base_model", type=str,
                    default="projecte-aina/salamandra-instruct-7b",
                    help="Nom o ruta del model base (HF hub o path local)")
    ap.add_argument("--data_dir", type=str, default=str(DATA_DIR),
                    help="Directori base amb subcarpetes que contenen target_une_ad.srt i free_ad.txt")
    ap.add_argument("--output_dir", type=str, default=str(BASE_DIR / "lora_output"),
                    help="Directori on desar l'adapter LoRA")

    # Training hyperparameters.
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--gradient_accumulation", type=int, default=8)
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--max_length", type=int, default=2048)
    ap.add_argument("--warmup_ratio", type=float, default=0.03)
    ap.add_argument("--logging_steps", type=int, default=10)
    ap.add_argument("--save_steps", type=int, default=200)
    ap.add_argument("--eval_steps", type=int, default=200)

    # LoRA-specific knobs.
    ap.add_argument("--r", type=int, default=16, help="Rank de LoRA")
    ap.add_argument("--alpha", type=int, default=32, help="Alpha de LoRA")
    ap.add_argument("--dropout", type=float, default=0.05, help="Dropout de LoRA")

    return ap.parse_args()
|
|
|
|
|
|
|
|
|
def main():
    """Entry point: locate data, build the tokenized dataset, and run LoRA fine-tuning."""
    args = parse_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"[lora] Buscant dades a: {data_dir}")
    pairs = find_training_pairs(data_dir)
    print(f"[lora] Nombre d'exemples trobats: {len(pairs)}")

    print(f"[lora] Carregant tokenizer de {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; reuse eos for padding.
        tokenizer.pad_token = tokenizer.eos_token

    print("[lora] Construint dataset tokenitzat...")
    dataset = build_dataset(pairs, tokenizer, max_length=args.max_length)

    print(f"[lora] Carregant model base {args.base_model} i aplicant LoRA...")
    model = create_lora_model(args.base_model, r=args.r, alpha=args.alpha, dropout=args.dropout)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        # Bugfix: the original set evaluation_strategy="steps" while passing
        # eval_dataset=None to the Trainer, which raises a ValueError at
        # construction time. Evaluation is disabled until an eval set exists;
        # eval_steps is kept so re-enabling it later is a one-line change.
        evaluation_strategy="no",
        eval_steps=args.eval_steps,
        save_total_limit=2,
        bf16=True,
        gradient_checkpointing=True,
        report_to=[],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=None,
        tokenizer=tokenizer,
    )

    print("[lora] Iniciant entrenament...")
    trainer.train()

    print("[lora] Guardant adapter LoRA...")
    # Saves only the LoRA adapter weights (PEFT), not the full base model.
    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    print(f"[lora] Entrenament completat. Adapter guardat a {output_dir}")
|
|
|
|
|
|
|
|
|
# Allow running this module directly as a training script.
if __name__ == "__main__":
    main()
|
|
|
|