Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
os.system("pip install -U transformers peft accelerate trl bitsandbytes datasets diffusers")
|
| 3 |
os.system("pip install spaces-0.1.0-py3-none-any.whl")
|
| 4 |
-
import os
|
| 5 |
import io
|
| 6 |
import json
|
| 7 |
import tempfile
|
|
@@ -17,9 +16,6 @@ import re
|
|
| 17 |
import ast
|
| 18 |
from itertools import islice
|
| 19 |
from pathlib import Path
|
| 20 |
-
from collections import defaultdict
|
| 21 |
-
from datetime import datetime
|
| 22 |
-
|
| 23 |
import torch
|
| 24 |
import torch.nn.functional as F
|
| 25 |
from torch.utils.data import DataLoader
|
|
@@ -33,14 +29,15 @@ from langdetect import detect_langs
|
|
| 33 |
import textstat
|
| 34 |
from datasketch import MinHash, MinHashLSH
|
| 35 |
import gradio as gr
|
| 36 |
-
|
|
|
|
| 37 |
from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi
|
| 38 |
from transformers import (
|
| 39 |
AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer,
|
| 40 |
AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer,
|
| 41 |
-
|
| 42 |
AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification,
|
| 43 |
-
DataCollatorForTokenClassification, AutoModelForQuestionAnswering,
|
| 44 |
AutoProcessor, DataCollatorWithPadding, pipeline, CLIPTextModel, CLIPTokenizer,
|
| 45 |
DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig,
|
| 46 |
LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel,
|
|
@@ -57,8 +54,8 @@ from diffusers import (
|
|
| 57 |
)
|
| 58 |
import evaluate as hf_evaluate
|
| 59 |
from jinja2 import Template
|
|
|
|
| 60 |
|
| 61 |
-
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 62 |
logger = logging.getLogger(__name__)
|
| 63 |
|
| 64 |
if torch.cuda.is_available():
|
|
@@ -97,8 +94,7 @@ TASK_TO_PIPELINE_MAP = {
|
|
| 97 |
"DreamBooth LoRA (Text-to-Image)": "text-to-image",
|
| 98 |
}
|
| 99 |
|
| 100 |
-
MODEL_CARD_TEMPLATE = """
|
| 101 |
-
---
|
| 102 |
language: es
|
| 103 |
license: apache-2.0
|
| 104 |
tags:
|
|
@@ -136,8 +132,7 @@ Este modelo es una versi贸n afinada de [{base_model}](https://huggingface.co/{ba
|
|
| 136 |
- Gradio
|
| 137 |
"""
|
| 138 |
|
| 139 |
-
DATASET_CARD_TEMPLATE = """
|
| 140 |
-
---
|
| 141 |
license: mit
|
| 142 |
---
|
| 143 |
|
|
@@ -169,52 +164,6 @@ class DebiasingSFTTrainer(SFTTrainer):
|
|
| 169 |
break
|
| 170 |
return (loss, outputs) if return_outputs else loss
|
| 171 |
|
| 172 |
-
class DeduplicatedIterableDataset(IterableDataset):
|
| 173 |
-
def __init__(self, dataset, text_col, method, threshold=0.85, num_perm=128):
|
| 174 |
-
super().__init__(ex_iterable=iter([]))
|
| 175 |
-
self.dataset = dataset
|
| 176 |
-
self.text_col = text_col
|
| 177 |
-
self.method = method
|
| 178 |
-
self.threshold = threshold
|
| 179 |
-
self.num_perm = num_perm
|
| 180 |
-
if hasattr(dataset, '_info'):
|
| 181 |
-
self._info = dataset._info
|
| 182 |
-
elif hasattr(dataset, 'info'):
|
| 183 |
-
self._info = dataset.info
|
| 184 |
-
|
| 185 |
-
def __iter__(self):
|
| 186 |
-
if self.method == 'Exacta':
|
| 187 |
-
return self._exact_iter()
|
| 188 |
-
elif self.method == 'Sem谩ntica (MinHash)':
|
| 189 |
-
return self._minhash_iter()
|
| 190 |
-
else:
|
| 191 |
-
return iter(self.dataset)
|
| 192 |
-
|
| 193 |
-
def _exact_iter(self):
|
| 194 |
-
seen_texts = set()
|
| 195 |
-
for example in self.dataset:
|
| 196 |
-
text = example.get(self.text_col, "")
|
| 197 |
-
if text and isinstance(text, str):
|
| 198 |
-
if text not in seen_texts:
|
| 199 |
-
seen_texts.add(text)
|
| 200 |
-
yield example
|
| 201 |
-
else:
|
| 202 |
-
yield example
|
| 203 |
-
|
| 204 |
-
def _minhash_iter(self):
|
| 205 |
-
lsh = MinHashLSH(threshold=self.threshold, num_perm=self.num_perm)
|
| 206 |
-
for i, example in enumerate(self.dataset):
|
| 207 |
-
text = example.get(self.text_col, "")
|
| 208 |
-
if text and isinstance(text, str) and text.strip():
|
| 209 |
-
m = MinHash(num_perm=self.num_perm)
|
| 210 |
-
for d in text.split():
|
| 211 |
-
m.update(d.encode('utf8'))
|
| 212 |
-
if not lsh.query(m):
|
| 213 |
-
lsh.insert(f"key_{i}", m)
|
| 214 |
-
yield example
|
| 215 |
-
else:
|
| 216 |
-
yield example
|
| 217 |
-
|
| 218 |
def hf_login(token):
|
| 219 |
if not token:
|
| 220 |
return "Por favor, introduce un token."
|
|
@@ -330,6 +279,8 @@ def _load_hf_streaming(ids, split="train", probabilities=None):
|
|
| 330 |
if probabilities and len(probabilities) != len(streams):
|
| 331 |
logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.")
|
| 332 |
probabilities = None
|
|
|
|
|
|
|
| 333 |
return interleave_datasets(streams, probabilities=probabilities)
|
| 334 |
|
| 335 |
def _load_uploaded_stream(files):
|
|
@@ -517,7 +468,6 @@ def _create_training_args(output_dir, repo_id, **kwargs):
|
|
| 517 |
"save_strategy": "steps",
|
| 518 |
"logging_steps": int(kwargs.get('logging_steps', 10)),
|
| 519 |
"save_steps": int(kwargs.get('save_steps', 50)),
|
| 520 |
-
"evaluation_strategy": "steps" if kwargs.get('run_evaluation', False) else "no",
|
| 521 |
"eval_steps": int(kwargs.get('save_steps', 50)) if kwargs.get('run_evaluation', False) else None,
|
| 522 |
"learning_rate": float(kwargs.get('learning_rate', 2e-5)),
|
| 523 |
"fp16": kwargs.get('mixed_precision') == 'fp16' and device == 'cuda',
|
|
@@ -727,7 +677,6 @@ def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_c
|
|
| 727 |
eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log]
|
| 728 |
if eval_logs:
|
| 729 |
final_metrics = eval_logs[-1]
|
| 730 |
-
final_metrics = {k.replace('eval_', ''): v for k, v in final_metrics.items()}
|
| 731 |
|
| 732 |
yield update_logs_fn("Entrenamiento finalizado.", "Guardando")
|
| 733 |
output_dir = trainer.args.output_dir
|
|
@@ -777,16 +726,16 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
|
|
| 777 |
if kwargs.get('run_evaluation'):
|
| 778 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 779 |
for update in eval_dataset_gen:
|
| 780 |
-
if isinstance(update,
|
| 781 |
yield update
|
| 782 |
else:
|
| 783 |
eval_dataset = update
|
| 784 |
|
| 785 |
TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
|
| 786 |
-
trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config, "tokenizer": tokenizer
|
| 787 |
|
| 788 |
if is_dpo:
|
| 789 |
-
trainer_kwargs.update({"beta": 0.1, "max_prompt_length": int(kwargs.get('block_size')) // 2})
|
| 790 |
if eval_dataset:
|
| 791 |
eval_dataset = eval_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
|
| 792 |
else:
|
|
@@ -794,8 +743,18 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
|
|
| 794 |
trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs)})
|
| 795 |
if kwargs.get('enable_loss_reweighting'):
|
| 796 |
trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': kwargs.get('reweighting_factor', 2.0)})
|
| 797 |
-
|
| 798 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 799 |
final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
|
| 800 |
return final_model_path, final_metrics
|
| 801 |
|
|
@@ -812,26 +771,23 @@ def train_sequence_classification(model_name, train_dataset, repo_id, update_log
|
|
| 812 |
tokenizer_id = kwargs.get('tokenizer_name') or model_name
|
| 813 |
yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuraci贸n")
|
| 814 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
|
| 815 |
-
if tokenizer.pad_token is None:
|
| 816 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 817 |
|
| 818 |
yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuraci贸n")
|
| 819 |
model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
|
| 820 |
-
model.config.pad_token_id = tokenizer.pad_token_id
|
| 821 |
|
| 822 |
def preprocess(examples):
|
| 823 |
return tokenizer(examples[kwargs['text_col']], truncation=True, max_length=512)
|
| 824 |
-
train_dataset = train_dataset.map(preprocess
|
| 825 |
|
| 826 |
eval_dataset = None
|
| 827 |
if kwargs.get('run_evaluation'):
|
| 828 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 829 |
for update in eval_dataset_gen:
|
| 830 |
-
if isinstance(update,
|
| 831 |
yield update
|
| 832 |
else:
|
| 833 |
eval_dataset = update
|
| 834 |
-
if eval_dataset: eval_dataset = eval_dataset.map(preprocess
|
| 835 |
|
| 836 |
metric = hf_evaluate.load("accuracy")
|
| 837 |
def compute_metrics(eval_pred):
|
|
@@ -887,7 +843,7 @@ def train_token_classification(model_name, train_dataset, repo_id, update_logs_f
|
|
| 887 |
if kwargs.get('run_evaluation'):
|
| 888 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 889 |
for update in eval_dataset_gen:
|
| 890 |
-
if isinstance(update,
|
| 891 |
yield update
|
| 892 |
else:
|
| 893 |
eval_dataset = update
|
|
@@ -979,7 +935,7 @@ def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn,
|
|
| 979 |
eval_dataset_raw_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 980 |
eval_dataset_raw = None
|
| 981 |
for update in eval_dataset_raw_gen:
|
| 982 |
-
if isinstance(update,
|
| 983 |
yield update
|
| 984 |
else:
|
| 985 |
eval_dataset_raw = update
|
|
@@ -1023,7 +979,7 @@ def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card
|
|
| 1023 |
if kwargs.get('run_evaluation'):
|
| 1024 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 1025 |
for update in eval_dataset_gen:
|
| 1026 |
-
if isinstance(update,
|
| 1027 |
yield update
|
| 1028 |
else:
|
| 1029 |
eval_dataset = update
|
|
@@ -1067,30 +1023,22 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1067 |
|
| 1068 |
yield update_logs_fn("Configurando componentes de Diffusers...", "Text-to-Image (LoRA)")
|
| 1069 |
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
|
| 1070 |
-
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder"
|
| 1071 |
-
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae"
|
| 1072 |
-
unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet"
|
| 1073 |
noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
|
| 1074 |
|
| 1075 |
vae.requires_grad_(False)
|
| 1076 |
text_encoder.requires_grad_(False)
|
| 1077 |
unet.train()
|
| 1078 |
|
| 1079 |
-
yield update_logs_fn("Agregando adaptadores LoRA al
|
| 1080 |
unet_lora_config = LoraConfig(
|
| 1081 |
r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
|
| 1082 |
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
| 1083 |
)
|
| 1084 |
unet.add_adapter(unet_lora_config)
|
| 1085 |
|
| 1086 |
-
if kwargs.get('dreambooth_train_text_encoder', False):
|
| 1087 |
-
yield update_logs_fn("Agregando adaptadores LoRA al Text Encoder...", "DreamBooth LoRA")
|
| 1088 |
-
text_encoder_lora_config = LoraConfig(
|
| 1089 |
-
r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
|
| 1090 |
-
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
| 1091 |
-
)
|
| 1092 |
-
text_encoder.add_adapter(text_encoder_lora_config)
|
| 1093 |
-
|
| 1094 |
yield update_logs_fn("Procesando dataset de im谩genes...", "Text-to-Image (LoRA)")
|
| 1095 |
resolution = int(kwargs.get('diffusion_resolution', 512))
|
| 1096 |
|
|
@@ -1102,7 +1050,7 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1102 |
])
|
| 1103 |
|
| 1104 |
def preprocess_train(examples):
|
| 1105 |
-
images = [image.convert("RGB") for image in examples[kwargs.get('image_col', 'image')]]
|
| 1106 |
examples["pixel_values"] = [train_transforms(image) for image in images]
|
| 1107 |
examples["input_ids"] = tokenizer(examples[kwargs.get('text_col', 'text')], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
|
| 1108 |
return examples
|
|
@@ -1116,17 +1064,14 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1116 |
|
| 1117 |
def collate_fn(examples):
|
| 1118 |
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 1119 |
-
input_ids = torch.stack([
|
| 1120 |
return {"pixel_values": pixel_values, "input_ids": input_ids}
|
| 1121 |
|
| 1122 |
train_dataloader = DataLoader(processed_dataset, shuffle=True, collate_fn=collate_fn, batch_size=int(kwargs.get('batch_size', 1)))
|
| 1123 |
-
|
| 1124 |
-
params_to_optimize = list(unet.parameters())
|
| 1125 |
-
if kwargs.get('dreambooth_train_text_encoder', False):
|
| 1126 |
-
params_to_optimize += list(text_encoder.parameters())
|
| 1127 |
|
|
|
|
| 1128 |
optimizer = torch.optim.AdamW(
|
| 1129 |
-
|
| 1130 |
betas=(float(kwargs.get('adam_beta1', 0.9)), float(kwargs.get('adam_beta2', 0.999))),
|
| 1131 |
weight_decay=float(kwargs.get('weight_decay', 0.01)),
|
| 1132 |
eps=float(kwargs.get('adam_epsilon', 1e-8)),
|
|
@@ -1142,34 +1087,36 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1142 |
num_training_steps=max_train_steps,
|
| 1143 |
)
|
| 1144 |
|
| 1145 |
-
unet,
|
| 1146 |
-
unet,
|
| 1147 |
)
|
| 1148 |
|
|
|
|
| 1149 |
vae.to(accelerator.device, dtype=torch_dtype_auto)
|
| 1150 |
|
|
|
|
| 1151 |
global_step = 0
|
| 1152 |
final_loss = 0
|
| 1153 |
for epoch in range(num_epochs):
|
| 1154 |
for step, batch in enumerate(train_dataloader):
|
| 1155 |
with accelerator.accumulate(unet):
|
| 1156 |
-
latents = vae.encode(batch["pixel_values"].to(
|
| 1157 |
latents = latents * vae.config.scaling_factor
|
| 1158 |
noise = torch.randn_like(latents)
|
| 1159 |
bsz = latents.shape[0]
|
| 1160 |
-
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
|
|
|
|
|
|
| 1161 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 1162 |
-
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
|
|
|
| 1163 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 1164 |
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
| 1165 |
final_loss = loss.detach().item()
|
| 1166 |
|
| 1167 |
accelerator.backward(loss)
|
| 1168 |
if accelerator.sync_gradients:
|
| 1169 |
-
|
| 1170 |
-
if kwargs.get('dreambooth_train_text_encoder', False):
|
| 1171 |
-
params_to_clip += list(text_encoder.parameters())
|
| 1172 |
-
accelerator.clip_grad_norm_(params_to_clip, float(kwargs.get('max_grad_norm', 1.0)))
|
| 1173 |
|
| 1174 |
optimizer.step()
|
| 1175 |
lr_scheduler.step()
|
|
@@ -1177,21 +1124,16 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1177 |
|
| 1178 |
if accelerator.is_main_process:
|
| 1179 |
if global_step % int(kwargs.get('logging_steps', 10)) == 0:
|
| 1180 |
-
yield update_logs_fn(f"Epoch {epoch}, Step {step}, Loss: {final_loss
|
| 1181 |
global_step += 1
|
| 1182 |
-
if global_step >= max_train_steps:
|
| 1183 |
-
break
|
| 1184 |
-
if global_step >= max_train_steps:
|
| 1185 |
-
break
|
| 1186 |
|
|
|
|
| 1187 |
accelerator.wait_for_everyone()
|
| 1188 |
if accelerator.is_main_process:
|
| 1189 |
-
|
| 1190 |
-
|
| 1191 |
-
|
| 1192 |
-
|
| 1193 |
-
torch_dtype=torch_dtype_auto,
|
| 1194 |
-
)
|
| 1195 |
pipeline.save_pretrained(output_dir)
|
| 1196 |
|
| 1197 |
with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
|
|
@@ -1206,6 +1148,7 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1206 |
torch.cuda.empty_cache()
|
| 1207 |
return output_dir, {"final_loss": final_loss}
|
| 1208 |
|
|
|
|
| 1209 |
def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 1210 |
if device == 'cpu':
|
| 1211 |
raise ValueError("El entrenamiento de DreamBooth solo es compatible con GPU CUDA.")
|
|
@@ -1220,7 +1163,7 @@ def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, mo
|
|
| 1220 |
|
| 1221 |
train_dataset = train_dataset.map(add_prompt)
|
| 1222 |
|
| 1223 |
-
yield update_logs_fn(f"Usando el prompt de instancia para todas las im谩genes: '{dreambooth_prompt}'", "DreamBooth LoRA")
|
| 1224 |
|
| 1225 |
final_model_path, final_metrics = yield from train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs)
|
| 1226 |
return final_model_path, final_metrics
|
|
@@ -1251,8 +1194,9 @@ def _get_data_processing_pipeline(**kwargs):
|
|
| 1251 |
if train_dataset is None:
|
| 1252 |
train_dataset = hf_train_dataset
|
| 1253 |
else:
|
|
|
|
| 1254 |
all_streams = [train_dataset, hf_train_dataset]
|
| 1255 |
-
all_probs = [0.5, 0.5]
|
| 1256 |
train_dataset = interleave_datasets(all_streams, probabilities=all_probs)
|
| 1257 |
|
| 1258 |
if train_dataset is None:
|
|
@@ -1262,8 +1206,7 @@ def _get_data_processing_pipeline(**kwargs):
|
|
| 1262 |
text_col, image_col, audio_col, label_col = _guess_columns(first_example)
|
| 1263 |
kwargs.update({'text_col': text_col, 'image_col': image_col, 'audio_col': audio_col, 'label_col': label_col, 'uploaded_val_data': uploaded_val_data})
|
| 1264 |
|
| 1265 |
-
|
| 1266 |
-
if is_text_task:
|
| 1267 |
if any([kwargs.get('remove_html_tags'), kwargs.get('normalize_whitespace'), kwargs.get('remove_urls_emails'), kwargs.get('redact_pii')]):
|
| 1268 |
clean_kwargs = {k:v for k,v in kwargs.items() if k in ['remove_html_tags', 'normalize_whitespace', 'remove_urls_emails', 'redact_pii']}
|
| 1269 |
train_dataset = train_dataset.map(lambda ex: _clean_text(ex, text_col, **clean_kwargs))
|
|
@@ -1286,13 +1229,35 @@ def _get_data_processing_pipeline(**kwargs):
|
|
| 1286 |
|
| 1287 |
dedup_method = kwargs.get('deduplication_method')
|
| 1288 |
if dedup_method != 'Ninguna':
|
| 1289 |
-
|
| 1290 |
-
|
| 1291 |
-
|
| 1292 |
-
|
| 1293 |
-
|
| 1294 |
-
|
| 1295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1296 |
|
| 1297 |
return train_dataset, kwargs
|
| 1298 |
|
|
@@ -1385,6 +1350,7 @@ def _train_and_upload(**kwargs):
|
|
| 1385 |
raise Exception(f"No se pudo cargar el tokenizer base '{tokenizer_id}' para el modelo desde cero: {e}")
|
| 1386 |
base_model_id_for_training = temp_model_dir
|
| 1387 |
kwargs["peft"] = False
|
|
|
|
| 1388 |
kwargs['tokenizer_name'] = temp_model_dir
|
| 1389 |
yield update_logs(f"Modelo {architecture} inicializado en {temp_model_dir}.", "Modelo Cero") + (gr.update(), gr.update())
|
| 1390 |
|
|
@@ -1397,6 +1363,7 @@ def _train_and_upload(**kwargs):
|
|
| 1397 |
os.environ["WANDB_PROJECT"] = kwargs.get('wandb_project_input') or f"{repo_base}"
|
| 1398 |
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
|
| 1399 |
|
|
|
|
| 1400 |
model_card_content = MODEL_CARD_TEMPLATE.format(
|
| 1401 |
repo_id=repo_id, base_model=model_name, base_model_name=model_name.split('/')[-1],
|
| 1402 |
training_mode=kwargs.get('training_mode'),
|
|
@@ -1423,11 +1390,8 @@ def _train_and_upload(**kwargs):
|
|
| 1423 |
train_generator = train_func(base_model_id_for_training, train_dataset, repo_id, update_logs, model_card_content, **kwargs)
|
| 1424 |
while True:
|
| 1425 |
try:
|
| 1426 |
-
|
| 1427 |
-
|
| 1428 |
-
yield update + (gr.update(), gr.update())
|
| 1429 |
-
else:
|
| 1430 |
-
pass
|
| 1431 |
except StopIteration as e:
|
| 1432 |
final_model_path, final_metrics = e.value
|
| 1433 |
break
|
|
@@ -1441,7 +1405,7 @@ def _train_and_upload(**kwargs):
|
|
| 1441 |
eval_dataset_perp = None
|
| 1442 |
eval_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), lambda m, p: update_logs(m, p))
|
| 1443 |
for update in eval_gen:
|
| 1444 |
-
if isinstance(update,
|
| 1445 |
yield update + (gr.update(), gr.update())
|
| 1446 |
else:
|
| 1447 |
eval_dataset_perp = update
|
|
@@ -1561,6 +1525,7 @@ def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, s
|
|
| 1561 |
for item in all_data:
|
| 1562 |
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
| 1563 |
|
|
|
|
| 1564 |
readme_content = DATASET_CARD_TEMPLATE.format(
|
| 1565 |
repo_id=repo_id,
|
| 1566 |
creation_type=creation_type,
|
|
@@ -1598,13 +1563,9 @@ def gradio_preview_data_wrapper(*args):
|
|
| 1598 |
dataset, processed_kwargs = _get_data_processing_pipeline(**kwargs)
|
| 1599 |
text_col = processed_kwargs.get('text_col')
|
| 1600 |
|
| 1601 |
-
model_id_for_tokenizer = kwargs.get('model_base_input')
|
| 1602 |
-
if not model_id_for_tokenizer:
|
| 1603 |
-
raise ValueError("Se necesita un ID de modelo base para cargar el tokenizer para la vista previa.")
|
| 1604 |
-
|
| 1605 |
-
tokenizer_id = kwargs.get('tokenizer_name') or model_id_for_tokenizer
|
| 1606 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 1607 |
-
|
|
|
|
| 1608 |
)
|
| 1609 |
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
|
| 1610 |
if kwargs.get('chat_template_jinja', '').strip(): tokenizer.chat_template = kwargs['chat_template_jinja']
|
|
@@ -1613,15 +1574,15 @@ def gradio_preview_data_wrapper(*args):
|
|
| 1613 |
for i, example in enumerate(islice(dataset, 5)):
|
| 1614 |
formatted_text = ""
|
| 1615 |
if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
|
| 1616 |
-
formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2
|
| 1617 |
else:
|
| 1618 |
formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
|
| 1619 |
|
| 1620 |
preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
|
| 1621 |
|
| 1622 |
preview_text = "\n".join(preview_samples)
|
| 1623 |
-
if not
|
| 1624 |
-
preview_text = "No se pudieron generar muestras. Revisa la configuraci贸n del dataset
|
| 1625 |
yield preview_text
|
| 1626 |
|
| 1627 |
except Exception as e:
|
|
@@ -1643,7 +1604,6 @@ def toggle_task_specific_ui(training_mode):
|
|
| 1643 |
is_sft = "Causal" in training_mode
|
| 1644 |
is_ner = "Token Classification" in training_mode
|
| 1645 |
is_diffusion = training_mode in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
|
| 1646 |
-
is_streaming = not is_diffusion
|
| 1647 |
|
| 1648 |
return (
|
| 1649 |
gr.update(visible=is_classification or is_ner),
|
|
@@ -1653,10 +1613,10 @@ def toggle_task_specific_ui(training_mode):
|
|
| 1653 |
gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
|
| 1654 |
gr.update(visible=not is_diffusion),
|
| 1655 |
gr.update(visible=is_diffusion),
|
| 1656 |
-
gr.update(visible=not
|
| 1657 |
-
gr.update(visible=is_streaming)
|
| 1658 |
)
|
| 1659 |
|
|
|
|
| 1660 |
def toggle_auto_modules_ui(is_auto):
|
| 1661 |
return gr.update(visible=not is_auto)
|
| 1662 |
|
|
@@ -1690,7 +1650,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1690 |
dset_file_uploads = gr.File(label="Subir Archivos (.jsonl, .csv, .txt)", file_count="multiple")
|
| 1691 |
dset_create_button = gr.Button("Crear y Subir Dataset", variant="primary")
|
| 1692 |
with gr.Column(scale=2):
|
| 1693 |
-
dset_status_output = gr.Textbox(label="Estado", lines=10
|
| 1694 |
dset_link_output = gr.Markdown()
|
| 1695 |
|
| 1696 |
dset_creation_type.change(toggle_dataset_creator_ui, inputs=[dset_creation_type], outputs=[dset_synth_group, dset_file_group])
|
|
@@ -1746,7 +1706,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1746 |
with gr.Accordion("Avanzados", open=False):
|
| 1747 |
warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento")
|
| 1748 |
weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01")
|
| 1749 |
-
max_grad_norm = gr.Textbox(label="Norma M谩xima de Gradiente", value="
|
| 1750 |
logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
|
| 1751 |
save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
|
| 1752 |
save_total_limit = gr.Textbox(label="L铆mite Total de Guardado", value="1")
|
|
@@ -1806,6 +1766,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1806 |
diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resoluci贸n")
|
| 1807 |
with gr.Group(visible=False) as dreambooth_ui:
|
| 1808 |
dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
|
|
|
|
|
|
|
|
|
|
| 1809 |
dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
|
| 1810 |
with gr.Group(visible=False) as classification_labels_ui:
|
| 1811 |
classification_labels = gr.Textbox(label="Etiquetas de Clasificaci贸n (csv)", placeholder="p.ej. positivo,negativo")
|
|
@@ -1824,6 +1787,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1824 |
enable_cda = gr.Checkbox(label="Habilitar Aumentaci贸n Contrafactual (CDA)", value=False)
|
| 1825 |
cda_json_config = gr.Textbox(label="Configuraci贸n CDA (JSON)", placeholder='[["ella", "茅l"], ["mujer", "hombre"]]')
|
| 1826 |
|
|
|
|
| 1827 |
with gr.Accordion("馃攲 Integraciones", open=False):
|
| 1828 |
wandb_api_key_input = gr.Textbox(label="Clave API de W&B", type="password")
|
| 1829 |
wandb_project_input = gr.Textbox(label="Proyecto W&B")
|
|
@@ -1868,7 +1832,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1868 |
"diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
|
| 1869 |
"enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms,
|
| 1870 |
"wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
|
| 1871 |
-
"dreambooth_instance_prompt": dreambooth_instance_prompt,
|
|
|
|
| 1872 |
"dreambooth_train_text_encoder": dreambooth_train_text_encoder
|
| 1873 |
}
|
| 1874 |
|
|
@@ -1940,4 +1905,4 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1940 |
outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
|
| 1941 |
)
|
| 1942 |
|
| 1943 |
-
demo.queue().launch(
|
|
|
|
| 1 |
import os
|
| 2 |
os.system("pip install -U transformers peft accelerate trl bitsandbytes datasets diffusers")
|
| 3 |
os.system("pip install spaces-0.1.0-py3-none-any.whl")
|
|
|
|
| 4 |
import io
|
| 5 |
import json
|
| 6 |
import tempfile
|
|
|
|
| 16 |
import ast
|
| 17 |
from itertools import islice
|
| 18 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
| 19 |
import torch
|
| 20 |
import torch.nn.functional as F
|
| 21 |
from torch.utils.data import DataLoader
|
|
|
|
| 29 |
import textstat
|
| 30 |
from datasketch import MinHash, MinHashLSH
|
| 31 |
import gradio as gr
|
| 32 |
+
import spaces
|
| 33 |
+
from datasets import load_dataset, IterableDataset, Dataset, DatasetDict
|
| 34 |
from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi
|
| 35 |
from transformers import (
|
| 36 |
AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer,
|
| 37 |
AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer,
|
| 38 |
+
AutoModelForImageClassification,
|
| 39 |
AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification,
|
| 40 |
+
DataCollatorForTokenClassification, AutoModelForQuestionAnswering,
|
| 41 |
AutoProcessor, DataCollatorWithPadding, pipeline, CLIPTextModel, CLIPTokenizer,
|
| 42 |
DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig,
|
| 43 |
LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel,
|
|
|
|
| 54 |
)
|
| 55 |
import evaluate as hf_evaluate
|
| 56 |
from jinja2 import Template
|
| 57 |
+
from collections import defaultdict
|
| 58 |
|
|
|
|
| 59 |
logger = logging.getLogger(__name__)
|
| 60 |
|
| 61 |
if torch.cuda.is_available():
|
|
|
|
| 94 |
"DreamBooth LoRA (Text-to-Image)": "text-to-image",
|
| 95 |
}
|
| 96 |
|
| 97 |
+
MODEL_CARD_TEMPLATE = """---
|
|
|
|
| 98 |
language: es
|
| 99 |
license: apache-2.0
|
| 100 |
tags:
|
|
|
|
| 132 |
- Gradio
|
| 133 |
"""
|
| 134 |
|
| 135 |
+
DATASET_CARD_TEMPLATE = """---
|
|
|
|
| 136 |
license: mit
|
| 137 |
---
|
| 138 |
|
|
|
|
| 164 |
break
|
| 165 |
return (loss, outputs) if return_outputs else loss
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
def hf_login(token):
|
| 168 |
if not token:
|
| 169 |
return "Por favor, introduce un token."
|
|
|
|
| 279 |
if probabilities and len(probabilities) != len(streams):
|
| 280 |
logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.")
|
| 281 |
probabilities = None
|
| 282 |
+
|
| 283 |
+
from datasets import interleave_datasets
|
| 284 |
return interleave_datasets(streams, probabilities=probabilities)
|
| 285 |
|
| 286 |
def _load_uploaded_stream(files):
|
|
|
|
| 468 |
"save_strategy": "steps",
|
| 469 |
"logging_steps": int(kwargs.get('logging_steps', 10)),
|
| 470 |
"save_steps": int(kwargs.get('save_steps', 50)),
|
|
|
|
| 471 |
"eval_steps": int(kwargs.get('save_steps', 50)) if kwargs.get('run_evaluation', False) else None,
|
| 472 |
"learning_rate": float(kwargs.get('learning_rate', 2e-5)),
|
| 473 |
"fp16": kwargs.get('mixed_precision') == 'fp16' and device == 'cuda',
|
|
|
|
| 677 |
eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log]
|
| 678 |
if eval_logs:
|
| 679 |
final_metrics = eval_logs[-1]
|
|
|
|
| 680 |
|
| 681 |
yield update_logs_fn("Entrenamiento finalizado.", "Guardando")
|
| 682 |
output_dir = trainer.args.output_dir
|
|
|
|
| 726 |
if kwargs.get('run_evaluation'):
|
| 727 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 728 |
for update in eval_dataset_gen:
|
| 729 |
+
if isinstance(update, tuple):
|
| 730 |
yield update
|
| 731 |
else:
|
| 732 |
eval_dataset = update
|
| 733 |
|
| 734 |
TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
|
| 735 |
+
trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config, "tokenizer": tokenizer}
|
| 736 |
|
| 737 |
if is_dpo:
|
| 738 |
+
trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2})
|
| 739 |
if eval_dataset:
|
| 740 |
eval_dataset = eval_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
|
| 741 |
else:
|
|
|
|
| 743 |
trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs)})
|
| 744 |
if kwargs.get('enable_loss_reweighting'):
|
| 745 |
trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': kwargs.get('reweighting_factor', 2.0)})
|
| 746 |
+
|
| 747 |
+
try:
|
| 748 |
+
trainer = TrainerClass(**trainer_kwargs)
|
| 749 |
+
except TypeError as e:
|
| 750 |
+
if "unexpected keyword argument 'tokenizer'" in str(e):
|
| 751 |
+
logger.warning("Caught TypeError for tokenizer argument. Retrying without it for TRL compatibility.")
|
| 752 |
+
trainer_kwargs.pop("tokenizer", None)
|
| 753 |
+
trainer = TrainerClass(**trainer_kwargs)
|
| 754 |
+
trainer.tokenizer = tokenizer
|
| 755 |
+
else:
|
| 756 |
+
raise e
|
| 757 |
+
|
| 758 |
final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
|
| 759 |
return final_model_path, final_metrics
|
| 760 |
|
|
|
|
| 771 |
tokenizer_id = kwargs.get('tokenizer_name') or model_name
|
| 772 |
yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuraci贸n")
|
| 773 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
|
|
|
|
|
|
|
| 774 |
|
| 775 |
yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuraci贸n")
|
| 776 |
model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
|
|
|
|
| 777 |
|
| 778 |
def preprocess(examples):
    """Tokenize the configured text column, truncating each text to 512 tokens."""
    texts = examples[kwargs['text_col']]
    return tokenizer(texts, truncation=True, max_length=512)
|
| 780 |
+
train_dataset = train_dataset.map(preprocess)
|
| 781 |
|
| 782 |
eval_dataset = None
|
| 783 |
if kwargs.get('run_evaluation'):
|
| 784 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 785 |
for update in eval_dataset_gen:
|
| 786 |
+
if isinstance(update, tuple):
|
| 787 |
yield update
|
| 788 |
else:
|
| 789 |
eval_dataset = update
|
| 790 |
+
if eval_dataset: eval_dataset = eval_dataset.map(preprocess)
|
| 791 |
|
| 792 |
metric = hf_evaluate.load("accuracy")
|
| 793 |
def compute_metrics(eval_pred):
|
|
|
|
| 843 |
if kwargs.get('run_evaluation'):
|
| 844 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 845 |
for update in eval_dataset_gen:
|
| 846 |
+
if isinstance(update, tuple):
|
| 847 |
yield update
|
| 848 |
else:
|
| 849 |
eval_dataset = update
|
|
|
|
| 935 |
eval_dataset_raw_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 936 |
eval_dataset_raw = None
|
| 937 |
for update in eval_dataset_raw_gen:
|
| 938 |
+
if isinstance(update, tuple):
|
| 939 |
yield update
|
| 940 |
else:
|
| 941 |
eval_dataset_raw = update
|
|
|
|
| 979 |
if kwargs.get('run_evaluation'):
|
| 980 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 981 |
for update in eval_dataset_gen:
|
| 982 |
+
if isinstance(update, tuple):
|
| 983 |
yield update
|
| 984 |
else:
|
| 985 |
eval_dataset = update
|
|
|
|
| 1023 |
|
| 1024 |
yield update_logs_fn("Configurando componentes de Diffusers...", "Text-to-Image (LoRA)")
|
| 1025 |
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
|
| 1026 |
+
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
|
| 1027 |
+
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
|
| 1028 |
+
unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
|
| 1029 |
noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
|
| 1030 |
|
| 1031 |
vae.requires_grad_(False)
|
| 1032 |
text_encoder.requires_grad_(False)
|
| 1033 |
unet.train()
|
| 1034 |
|
| 1035 |
+
yield update_logs_fn("Agregando adaptadores LoRA al modelo...", "Text-to-Image (LoRA)")
|
| 1036 |
unet_lora_config = LoraConfig(
|
| 1037 |
r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
|
| 1038 |
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
| 1039 |
)
|
| 1040 |
unet.add_adapter(unet_lora_config)
|
| 1041 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1042 |
yield update_logs_fn("Procesando dataset de im谩genes...", "Text-to-Image (LoRA)")
|
| 1043 |
resolution = int(kwargs.get('diffusion_resolution', 512))
|
| 1044 |
|
|
|
|
| 1050 |
])
|
| 1051 |
|
| 1052 |
def preprocess_train(examples):
    """Load images as RGB, apply the train transforms, and tokenize captions.

    Mutates *examples* in place, adding "pixel_values" and "input_ids",
    and returns it (format expected by `datasets.map` with batching).
    """
    image_files = examples[kwargs.get('image_col', 'image')]
    loaded = [Image.open(f).convert("RGB") for f in image_files]
    examples["pixel_values"] = [train_transforms(img) for img in loaded]
    captions = examples[kwargs.get('text_col', 'text')]
    examples["input_ids"] = tokenizer(
        captions,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids
    return examples
|
|
|
|
| 1064 |
|
| 1065 |
def collate_fn(examples):
    """Stack per-example image tensors and token ids into one training batch."""
    batch = {
        "pixel_values": torch.stack([ex["pixel_values"] for ex in examples]),
        "input_ids": torch.stack([ex["input_ids"] for ex in examples]),
    }
    return batch
|
| 1069 |
|
| 1070 |
train_dataloader = DataLoader(processed_dataset, shuffle=True, collate_fn=collate_fn, batch_size=int(kwargs.get('batch_size', 1)))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1071 |
|
| 1072 |
+
yield update_logs_fn("Configurando optimizador y planificador...", "Text-to-Image (LoRA)")
|
| 1073 |
optimizer = torch.optim.AdamW(
|
| 1074 |
+
unet.parameters(), lr=float(kwargs.get('learning_rate', 2e-5)),
|
| 1075 |
betas=(float(kwargs.get('adam_beta1', 0.9)), float(kwargs.get('adam_beta2', 0.999))),
|
| 1076 |
weight_decay=float(kwargs.get('weight_decay', 0.01)),
|
| 1077 |
eps=float(kwargs.get('adam_epsilon', 1e-8)),
|
|
|
|
| 1087 |
num_training_steps=max_train_steps,
|
| 1088 |
)
|
| 1089 |
|
| 1090 |
+
unet, optimizer, train_dataloader, lr_scheduler, text_encoder, vae = accelerator.prepare(
|
| 1091 |
+
unet, optimizer, train_dataloader, lr_scheduler, text_encoder, vae
|
| 1092 |
)
|
| 1093 |
|
| 1094 |
+
text_encoder.to(accelerator.device, dtype=torch_dtype_auto)
|
| 1095 |
vae.to(accelerator.device, dtype=torch_dtype_auto)
|
| 1096 |
|
| 1097 |
+
yield update_logs_fn("Iniciando bucle de entrenamiento de difusi贸n...", "Text-to-Image (LoRA)")
|
| 1098 |
global_step = 0
|
| 1099 |
final_loss = 0
|
| 1100 |
for epoch in range(num_epochs):
|
| 1101 |
for step, batch in enumerate(train_dataloader):
|
| 1102 |
with accelerator.accumulate(unet):
|
| 1103 |
+
latents = vae.encode(batch["pixel_values"].to(torch_dtype_auto)).latent_dist.sample()
|
| 1104 |
latents = latents * vae.config.scaling_factor
|
| 1105 |
noise = torch.randn_like(latents)
|
| 1106 |
bsz = latents.shape[0]
|
| 1107 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
| 1108 |
+
timesteps = timesteps.long()
|
| 1109 |
+
|
| 1110 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 1111 |
+
encoder_hidden_states = text_encoder(batch["input_ids"].to(accelerator.device))[0]
|
| 1112 |
+
|
| 1113 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 1114 |
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
| 1115 |
final_loss = loss.detach().item()
|
| 1116 |
|
| 1117 |
accelerator.backward(loss)
|
| 1118 |
if accelerator.sync_gradients:
|
| 1119 |
+
accelerator.clip_grad_norm_(unet.parameters(), float(kwargs.get('max_grad_norm', 1.0)))
|
|
|
|
|
|
|
|
|
|
| 1120 |
|
| 1121 |
optimizer.step()
|
| 1122 |
lr_scheduler.step()
|
|
|
|
| 1124 |
|
| 1125 |
if accelerator.is_main_process:
|
| 1126 |
if global_step % int(kwargs.get('logging_steps', 10)) == 0:
|
| 1127 |
+
yield update_logs_fn(f"Epoch {epoch}, Step {step}, Loss: {final_loss}", "Text-to-Image (LoRA)")
|
| 1128 |
global_step += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1129 |
|
| 1130 |
+
yield update_logs_fn("Entrenamiento completado, guardando modelo...", "Text-to-Image (LoRA)")
|
| 1131 |
accelerator.wait_for_everyone()
|
| 1132 |
if accelerator.is_main_process:
|
| 1133 |
+
unwrapped_unet = accelerator.unwrap_model(unet)
|
| 1134 |
+
|
| 1135 |
+
pipeline = StableDiffusionText2ImagePipeline.from_pretrained(model_name, torch_dtype=torch_dtype_auto)
|
| 1136 |
+
pipeline.unet.load_state_dict(unwrapped_unet.state_dict())
|
|
|
|
|
|
|
| 1137 |
pipeline.save_pretrained(output_dir)
|
| 1138 |
|
| 1139 |
with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
|
|
|
|
| 1148 |
torch.cuda.empty_cache()
|
| 1149 |
return output_dir, {"final_loss": final_loss}
|
| 1150 |
|
| 1151 |
+
|
| 1152 |
def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 1153 |
if device == 'cpu':
|
| 1154 |
raise ValueError("El entrenamiento de DreamBooth solo es compatible con GPU CUDA.")
|
|
|
|
| 1163 |
|
| 1164 |
train_dataset = train_dataset.map(add_prompt)
|
| 1165 |
|
| 1166 |
+
yield update_logs_fn(f"Usando el prompt de instancia para todas las im谩genes: '{dreambooth_prompt}'", "DreamBooth LoRA (Text-to-Image)")
|
| 1167 |
|
| 1168 |
final_model_path, final_metrics = yield from train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs)
|
| 1169 |
return final_model_path, final_metrics
|
|
|
|
| 1194 |
if train_dataset is None:
|
| 1195 |
train_dataset = hf_train_dataset
|
| 1196 |
else:
|
| 1197 |
+
from datasets import interleave_datasets
|
| 1198 |
all_streams = [train_dataset, hf_train_dataset]
|
| 1199 |
+
all_probs = [0.5, 0.5] if not probabilities else [probabilities] + probabilities[1:]
|
| 1200 |
train_dataset = interleave_datasets(all_streams, probabilities=all_probs)
|
| 1201 |
|
| 1202 |
if train_dataset is None:
|
|
|
|
| 1206 |
text_col, image_col, audio_col, label_col = _guess_columns(first_example)
|
| 1207 |
kwargs.update({'text_col': text_col, 'image_col': image_col, 'audio_col': audio_col, 'label_col': label_col, 'uploaded_val_data': uploaded_val_data})
|
| 1208 |
|
| 1209 |
+
if kwargs['training_mode'] not in ["DreamBooth LoRA (Text-to-Image)", "Text-to-Image (LoRA)"]:
|
|
|
|
| 1210 |
if any([kwargs.get('remove_html_tags'), kwargs.get('normalize_whitespace'), kwargs.get('remove_urls_emails'), kwargs.get('redact_pii')]):
|
| 1211 |
clean_kwargs = {k:v for k,v in kwargs.items() if k in ['remove_html_tags', 'normalize_whitespace', 'remove_urls_emails', 'redact_pii']}
|
| 1212 |
train_dataset = train_dataset.map(lambda ex: _clean_text(ex, text_col, **clean_kwargs))
|
|
|
|
| 1229 |
|
| 1230 |
dedup_method = kwargs.get('deduplication_method')
|
| 1231 |
if dedup_method != 'Ninguna':
|
| 1232 |
+
base_iterator = train_dataset
|
| 1233 |
+
if dedup_method == 'Exacta':
|
| 1234 |
+
def dedup_generator_exact():
    """Yield examples whose text has not been seen before (exact string match).

    Non-string texts are always passed through; empty strings are yielded
    but never recorded, so they are not deduplicated against each other.
    """
    seen = set()
    for ex in base_iterator:
        value = ex.get(text_col, "")
        is_str = isinstance(value, str)
        if is_str and value in seen:
            continue  # exact duplicate — drop
        if is_str and value:
            seen.add(value)
        yield ex
|
| 1242 |
+
train_dataset = IterableDataset.from_generator(dedup_generator_exact)
|
| 1243 |
+
elif dedup_method == 'Sem谩ntica (MinHash)':
|
| 1244 |
+
threshold = kwargs.get('minhash_threshold', 0.85)
|
| 1245 |
+
num_perm = int(kwargs.get('minhash_num_perm', 128))
|
| 1246 |
+
def dedup_generator_minhash():
    """Yield examples, dropping near-duplicate texts detected via MinHash LSH.

    Blank or non-string texts bypass fingerprinting and are always yielded.
    """
    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
    for idx, ex in enumerate(base_iterator):
        value = ex.get(text_col, "")
        if not (value and isinstance(value, str) and value.strip()):
            yield ex
            continue
        fingerprint = MinHash(num_perm=num_perm)
        for token in value.split():
            fingerprint.update(token.encode('utf8'))
        if lsh.query(fingerprint):
            continue  # near-duplicate of an already-kept example — drop
        lsh.insert(f"key_{idx}", fingerprint)
        yield ex
|
| 1259 |
+
train_dataset = IterableDataset.from_generator(dedup_generator_minhash)
|
| 1260 |
+
|
| 1261 |
|
| 1262 |
return train_dataset, kwargs
|
| 1263 |
|
|
|
|
| 1350 |
raise Exception(f"No se pudo cargar el tokenizer base '{tokenizer_id}' para el modelo desde cero: {e}")
|
| 1351 |
base_model_id_for_training = temp_model_dir
|
| 1352 |
kwargs["peft"] = False
|
| 1353 |
+
kwargs["merge_adapter"] = False
|
| 1354 |
kwargs['tokenizer_name'] = temp_model_dir
|
| 1355 |
yield update_logs(f"Modelo {architecture} inicializado en {temp_model_dir}.", "Modelo Cero") + (gr.update(), gr.update())
|
| 1356 |
|
|
|
|
| 1363 |
os.environ["WANDB_PROJECT"] = kwargs.get('wandb_project_input') or f"{repo_base}"
|
| 1364 |
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
|
| 1365 |
|
| 1366 |
+
from datetime import datetime
|
| 1367 |
model_card_content = MODEL_CARD_TEMPLATE.format(
|
| 1368 |
repo_id=repo_id, base_model=model_name, base_model_name=model_name.split('/')[-1],
|
| 1369 |
training_mode=kwargs.get('training_mode'),
|
|
|
|
| 1390 |
train_generator = train_func(base_model_id_for_training, train_dataset, repo_id, update_logs, model_card_content, **kwargs)
|
| 1391 |
while True:
|
| 1392 |
try:
|
| 1393 |
+
update_tuple = next(train_generator)
|
| 1394 |
+
yield update_tuple + (gr.update(), gr.update())
|
|
|
|
|
|
|
|
|
|
| 1395 |
except StopIteration as e:
|
| 1396 |
final_model_path, final_metrics = e.value
|
| 1397 |
break
|
|
|
|
| 1405 |
eval_dataset_perp = None
|
| 1406 |
eval_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), lambda m, p: update_logs(m, p))
|
| 1407 |
for update in eval_gen:
|
| 1408 |
+
if isinstance(update, tuple):
|
| 1409 |
yield update + (gr.update(), gr.update())
|
| 1410 |
else:
|
| 1411 |
eval_dataset_perp = update
|
|
|
|
| 1525 |
for item in all_data:
|
| 1526 |
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
| 1527 |
|
| 1528 |
+
from datetime import datetime
|
| 1529 |
readme_content = DATASET_CARD_TEMPLATE.format(
|
| 1530 |
repo_id=repo_id,
|
| 1531 |
creation_type=creation_type,
|
|
|
|
| 1563 |
dataset, processed_kwargs = _get_data_processing_pipeline(**kwargs)
|
| 1564 |
text_col = processed_kwargs.get('text_col')
|
| 1565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1566 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 1567 |
+
kwargs.get('tokenizer_name') or kwargs.get('model_base_input') or 'gpt2',
|
| 1568 |
+
trust_remote_code=True, use_fast=False
|
| 1569 |
)
|
| 1570 |
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
|
| 1571 |
if kwargs.get('chat_template_jinja', '').strip(): tokenizer.chat_template = kwargs['chat_template_jinja']
|
|
|
|
| 1574 |
for i, example in enumerate(islice(dataset, 5)):
|
| 1575 |
formatted_text = ""
|
| 1576 |
if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
|
| 1577 |
+
formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2)
|
| 1578 |
else:
|
| 1579 |
formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
|
| 1580 |
|
| 1581 |
preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
|
| 1582 |
|
| 1583 |
preview_text = "\n".join(preview_samples)
|
| 1584 |
+
if not preview_text:
|
| 1585 |
+
preview_text = "No se pudieron generar muestras. Revisa la configuraci贸n del dataset y el formato."
|
| 1586 |
yield preview_text
|
| 1587 |
|
| 1588 |
except Exception as e:
|
|
|
|
| 1604 |
is_sft = "Causal" in training_mode
|
| 1605 |
is_ner = "Token Classification" in training_mode
|
| 1606 |
is_diffusion = training_mode in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
|
|
|
|
| 1607 |
|
| 1608 |
return (
|
| 1609 |
gr.update(visible=is_classification or is_ner),
|
|
|
|
| 1613 |
gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
|
| 1614 |
gr.update(visible=not is_diffusion),
|
| 1615 |
gr.update(visible=is_diffusion),
|
| 1616 |
+
gr.update(visible=not is_diffusion),
|
|
|
|
| 1617 |
)
|
| 1618 |
|
| 1619 |
+
|
| 1620 |
def toggle_auto_modules_ui(is_auto):
    """Hide the manual target-modules UI when automatic detection is enabled."""
    manual_visible = not is_auto
    return gr.update(visible=manual_visible)
|
| 1622 |
|
|
|
|
| 1650 |
dset_file_uploads = gr.File(label="Subir Archivos (.jsonl, .csv, .txt)", file_count="multiple")
|
| 1651 |
dset_create_button = gr.Button("Crear y Subir Dataset", variant="primary")
|
| 1652 |
with gr.Column(scale=2):
|
| 1653 |
+
dset_status_output = gr.Textbox(label="Estado", lines=10)
|
| 1654 |
dset_link_output = gr.Markdown()
|
| 1655 |
|
| 1656 |
dset_creation_type.change(toggle_dataset_creator_ui, inputs=[dset_creation_type], outputs=[dset_synth_group, dset_file_group])
|
|
|
|
| 1706 |
with gr.Accordion("Avanzados", open=False):
|
| 1707 |
warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento")
|
| 1708 |
weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01")
|
| 1709 |
+
max_grad_norm = gr.Textbox(label="Norma M谩xima de Gradiente", value="0.3")
|
| 1710 |
logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
|
| 1711 |
save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
|
| 1712 |
save_total_limit = gr.Textbox(label="L铆mite Total de Guardado", value="1")
|
|
|
|
| 1766 |
diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resoluci贸n")
|
| 1767 |
with gr.Group(visible=False) as dreambooth_ui:
|
| 1768 |
dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
|
| 1769 |
+
dreambooth_class_prompt = gr.Textbox(label="Prompt de Clase (Opcional)", placeholder="p.ej. 'foto de perro'")
|
| 1770 |
+
dreambooth_num_class_images = gr.Slider(0, 1000, 100, step=10, label="N潞 de Im谩genes de Clase")
|
| 1771 |
+
dreambooth_prior_loss_weight = gr.Slider(0.0, 2.0, 1.0, label="Peso de P茅rdida a Priori")
|
| 1772 |
dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
|
| 1773 |
with gr.Group(visible=False) as classification_labels_ui:
|
| 1774 |
classification_labels = gr.Textbox(label="Etiquetas de Clasificaci贸n (csv)", placeholder="p.ej. positivo,negativo")
|
|
|
|
| 1787 |
enable_cda = gr.Checkbox(label="Habilitar Aumentaci贸n Contrafactual (CDA)", value=False)
|
| 1788 |
cda_json_config = gr.Textbox(label="Configuraci贸n CDA (JSON)", placeholder='[["ella", "茅l"], ["mujer", "hombre"]]')
|
| 1789 |
|
| 1790 |
+
|
| 1791 |
with gr.Accordion("馃攲 Integraciones", open=False):
|
| 1792 |
wandb_api_key_input = gr.Textbox(label="Clave API de W&B", type="password")
|
| 1793 |
wandb_project_input = gr.Textbox(label="Proyecto W&B")
|
|
|
|
| 1832 |
"diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
|
| 1833 |
"enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms,
|
| 1834 |
"wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
|
| 1835 |
+
"dreambooth_instance_prompt": dreambooth_instance_prompt, "dreambooth_class_prompt": dreambooth_class_prompt,
|
| 1836 |
+
"dreambooth_num_class_images": dreambooth_num_class_images, "dreambooth_prior_loss_weight": dreambooth_prior_loss_weight,
|
| 1837 |
"dreambooth_train_text_encoder": dreambooth_train_text_encoder
|
| 1838 |
}
|
| 1839 |
|
|
|
|
| 1905 |
outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
|
| 1906 |
)
|
| 1907 |
|
| 1908 |
+
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
|