|
|
import os |
|
|
|
|
|
import torch |
|
|
import logging |
|
|
import multiprocessing |
|
|
import threading |
|
|
from itertools import chain |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
from datasets import load_dataset, get_dataset_config_names, IterableDataset |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback |
|
|
from peft import LoraConfig, get_peft_model, PeftModel |
|
|
from huggingface_hub import login, whoami, create_repo, upload_folder |
|
|
from IPython.display import clear_output |
|
|
import gradio as gr |
|
|
from dotenv import load_dotenv |
|
|
import spaces |
|
|
|
|
|
# Best-effort: pull variables (e.g. HF_TOKEN) from a local .env file.
# A failure here must not prevent the app from starting, but it is logged
# instead of being silently discarded, and KeyboardInterrupt/SystemExit are
# no longer swallowed as they were by the previous bare ``except``.
try:
    load_dotenv()
except Exception as exc:
    logging.getLogger(__name__).warning("Could not load .env file: %s", exc)
|
|
|
|
|
class GradioProgressCallback(TrainerCallback):
    """Trainer callback that forwards step progress to a Gradio progress bar.

    The class is deliberately not wrapped in ``spaces.GPU``: that decorator is
    meant for GPU-bound functions, and a callback has to remain a plain
    ``TrainerCallback`` subclass so it can be instantiated and handed to the
    ``Trainer`` directly.
    """

    def __init__(self, progress_bar):
        """Remember the progress callable.

        Args:
            progress_bar: Callable invoked as ``progress_bar(fraction, desc=...)``.
        """
        self.progress_bar = progress_bar

    def on_step_end(self, args, state, control, **kwargs):
        """Report the completed fraction of training after every step.

        Args:
            args: Training arguments supplied by the trainer (unused).
            state: Trainer state; ``global_step`` and ``max_steps`` are read.
            control: Trainer control object, returned unchanged.
            **kwargs: Extra values supplied by the trainer (unused).

        Returns:
            The ``control`` object that was passed in.
        """
        # Skip the update when max_steps is not positive to avoid a
        # ZeroDivisionError inside the training loop.
        if state.global_step > 0 and state.max_steps > 0:
            self.progress_bar(
                state.global_step / state.max_steps,
                desc=f"Paso {state.global_step}/{state.max_steps}",
            )
        return control
|
|
|
|
|
@spaces.GPU()
def run_training(hf_token, model_name, new_repo_name, lora_r, lora_alpha, lora_dropout,
                 train_steps, learning_rate, batch_size, datasets_text, progress=gr.Progress()):
    """Fine-tune a causal LM with LoRA on streamed datasets, merge it and upload it.

    The pipeline authenticates against the Hugging Face Hub, opens every
    requested dataset (all of its configurations) in streaming mode, tokenizes
    the text line by line, trains a LoRA adapter for ``train_steps`` steps,
    merges the adapter into the base model and uploads the result.

    Args:
        hf_token: Hugging Face access token used for login and upload.
        model_name: Hub id of the base model and tokenizer.
        new_repo_name: Name of the repository created under the user's account.
        lora_r: LoRA rank (converted to ``int``).
        lora_alpha: LoRA alpha (converted to ``int``).
        lora_dropout: LoRA dropout probability.
        train_steps: Number of optimizer steps to run.
        learning_rate: Optimizer learning rate.
        batch_size: Per-device training batch size.
        datasets_text: Dataset ids separated by commas and/or newlines.
        progress: Gradio progress callable, ``progress(fraction, desc=...)``.

    Returns:
        A human-readable status string: either ``"Completado: <url>"`` or a
        message starting with ``"Error"`` describing the stage that failed.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # ---- Input validation: fail fast, before any slow network work. ----
    hf_token = (hf_token or "").strip()
    model_name = (model_name or "").strip()
    new_repo_name = (new_repo_name or "").strip()
    if not hf_token:
        return "Error de autenticación: falta el token de HuggingFace."
    if not model_name:
        return "Error: Debes indicar un modelo base."
    if not new_repo_name:
        return "Error: Debes indicar un nombre para el nuevo repositorio."
    try:
        max_steps_val = int(train_steps)
        batch_size_val = int(batch_size)
        learning_rate_val = float(learning_rate)
        lora_r_val = int(lora_r)
        lora_alpha_val = int(lora_alpha)
        lora_dropout_val = float(lora_dropout)
    except (TypeError, ValueError) as e:
        return f"Error: parámetros numéricos inválidos: {str(e)}"
    if max_steps_val < 1 or batch_size_val < 1:
        return "Error: Max Steps y Batch Size deben ser mayores o iguales a 1."

    os.environ["WANDB_DISABLED"] = "true"
    os.environ["HF_TOKEN"] = hf_token

    try:
        login(token=hf_token)
        username = whoami()["name"]
    except Exception as e:
        return f"Error de autenticación: {str(e)}"

    num_workers = multiprocessing.cpu_count()

    # Install a no-op stand-in when torch has no ``xla`` attribute.
    # NOTE(review): presumably a workaround for code that probes torch.xla
    # on machines without XLA support -- confirm it is still required.
    if not hasattr(torch, 'xla'):
        class DummyXLA:
            def __getattr__(self, name):
                return lambda *args, **kwargs: None
        torch.xla = DummyXLA()

    raw_items = (datasets_text or "").replace('\n', ',').split(',')
    dataset_list = [item.strip() for item in raw_items if item.strip()]

    def get_sample_text(ds):
        """Return the text of the first sample of ``ds``, or None if unreadable."""
        try:
            sample = next(iter(ds))
        except Exception as exc:
            logger.warning("Could not read a sample: %s", exc)
            return None
        if isinstance(sample, dict):
            return sample.get("text", str(sample))
        return str(sample)

    def load_single(ds_name, cfg):
        """Open one dataset configuration as a stream; None when it is unusable."""
        try:
            ds = load_dataset(ds_name, cfg, streaming=True)
            if isinstance(ds, dict):
                # Several splits were returned: use the first one.
                ds = next(iter(ds.values()))
        except Exception as exc:
            logger.warning("Could not load %s (config=%s): %s", ds_name, cfg, exc)
            return None
        # The probe only checks that the stream can be read; an empty first
        # text must not disqualify the whole dataset.
        return ds if get_sample_text(ds) is not None else None

    def load_all_datasets():
        """Load every configuration of every requested dataset concurrently."""
        progress(0.1, desc="Analizando configuraciones...")
        tasks = []
        for ds_name in dataset_list:
            try:
                configs = get_dataset_config_names(ds_name)
            except Exception as exc:
                logger.warning("Could not list configs for %s: %s", ds_name, exc)
                configs = []
            if configs:
                tasks.extend((ds_name, cfg) for cfg in configs)
            else:
                tasks.append((ds_name, None))

        progress(0.2, desc=f"Cargando {len(tasks)} fuentes...")
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(load_single, name, cfg) for name, cfg in tasks]
            # Collect in submission order so the stream order is deterministic.
            return [ds for ds in (f.result() for f in futures) if ds is not None]

    loaded_streams = load_all_datasets()
    if not loaded_streams:
        return "Error: No se pudo cargar ningún dataset válido."

    def all_samples():
        """Iterate over all loaded streams, one after another."""
        return chain.from_iterable(loaded_streams)

    progress(0.3, desc="Cargando Tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", add_eos_token=True, add_bos_token=True)
        tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        return f"Error cargando tokenizer: {str(e)}"

    def create_text_lines(sample):
        """Split a sample into its non-empty, stripped text lines."""
        if isinstance(sample, dict):
            text = sample.get("text", "\n".join(str(v) for v in sample.values() if isinstance(v, str)))
        else:
            text = str(sample)
        return [line.strip() for line in text.splitlines() if line.strip()]

    def process_sample(sample):
        """Tokenize every line of a sample; labels are a copy of the input ids."""
        results = []
        for line in create_text_lines(sample):
            tok = tokenizer(line, truncation=False)
            tok["labels"] = tok["input_ids"].copy()
            results.append(tok)
        return results

    def tokenize_batch(batch):
        """Tokenize a batch of samples in a thread pool and yield the results."""
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(process_sample, s) for s in batch]
            for future in futures:
                # Only the tokenization is guarded. The yield stays outside the
                # try block so that closing the generator (GeneratorExit) is
                # never swallowed, and broken samples are skipped with a log.
                try:
                    tokens = future.result()
                except Exception as exc:
                    logger.warning("Skipping sample that failed to tokenize: %s", exc)
                    continue
                yield from tokens

    def processed_samples_generator():
        """Yield tokenized training examples, processing samples 100 at a time."""
        batch = []
        for sample in all_samples():
            batch.append(sample)
            if len(batch) >= 100:
                yield from tokenize_batch(batch)
                batch = []
        if batch:
            yield from tokenize_batch(batch)

    progress(0.4, desc="Cargando Modelo...")
    try:
        original_model = AutoModelForCausalLM.from_pretrained(model_name)
    except Exception as e:
        return f"Error cargando modelo: {str(e)}"

    output_dir = "/content/final-checkpoint"
    final_path = "/content/merged_model"
    save_steps_val = max_steps_val // 2 if max_steps_val > 10 else 1

    try:
        peft_config = LoraConfig(
            r=lora_r_val,
            lora_alpha=lora_alpha_val,
            target_modules=["q_proj", "k_proj", "v_proj", "dense"],
            bias="none",
            lora_dropout=lora_dropout_val,
            task_type="CAUSAL_LM"
        )
        peft_model = get_peft_model(original_model, peft_config)
        peft_model.config.use_cache = False

        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size_val,
            gradient_accumulation_steps=1,
            max_steps=max_steps_val,
            learning_rate=learning_rate_val,
            optim="adamw_torch",
            logging_steps=5,
            save_strategy="steps",
            save_steps=save_steps_val,
            report_to="none"
        )

        processed_dataset = IterableDataset.from_generator(processed_samples_generator)

        # Samples have different lengths, so batches larger than one need
        # padding; labels are padded with -100 so the loss ignores them.
        from transformers import DataCollatorForSeq2Seq
        data_collator = DataCollatorForSeq2Seq(tokenizer, padding=True, label_pad_token_id=-100)

        trainer = Trainer(
            model=peft_model,
            train_dataset=processed_dataset,
            args=training_args,
            data_collator=data_collator,
            callbacks=[GradioProgressCallback(progress)]
        )

        progress(0.5, desc="Entrenando...")
        trainer.train()

        progress(0.8, desc="Guardando...")
        trainer.save_model(output_dir)

        progress(0.9, desc="Fusionando...")
        ft = PeftModel.from_pretrained(original_model, output_dir, torch_dtype=torch.float32, is_trainable=False).merge_and_unload()
        ft.save_pretrained(final_path, safe_serialization=True)
        tokenizer.save_pretrained(final_path)

        progress(0.95, desc="Subiendo...")
        full_repo = f"{username}/{new_repo_name}"
        create_repo(full_repo, token=hf_token, exist_ok=True)
        upload_folder(folder_path=final_path, repo_id=full_repo, token=hf_token)
    except Exception as e:
        # Report the failure in the UI like the earlier stages do, and keep
        # the full traceback in the server log.
        logger.exception("Training pipeline failed")
        return f"Error durante el entrenamiento: {str(e)}"

    return f"Completado: https://huggingface.co/{full_repo}"
|
|
|
|
|
# Dark-theme stylesheet for the interface, one CSS rule per literal.
# Adjacent string literals are concatenated into a single string.
custom_css = (
    "\n"
    "body {background-color: #0b0f19; color: #e0e6ed;}\n"
    ".gradio-container {max-width: 1200px !important; margin: 0 auto;}\n"
    "h1 {text-align: center; color: #00e5ff; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; text-transform: uppercase; letter-spacing: 2px;}\n"
    ".primary-btn {background: linear-gradient(135deg, #00C9FF 0%, #92FE9D 100%); border: none; color: #000; font-weight: 800; font-size: 16px; padding: 12px; transition: transform 0.2s;}\n"
    ".primary-btn:hover {transform: scale(1.02); filter: brightness(1.1);}\n"
    ".input-box textarea {font-family: 'Consolas', 'Monaco', monospace; font-size: 13px; background-color: #1a202c; color: #a0aec0; border: 1px solid #2d3748;}\n"
    ".gr-box {border-radius: 8px; background-color: #1a202c; border: 1px solid #2d3748;}\n"
    "label {color: #00e5ff !important; font-weight: bold;}\n"
)
|
|
|
|
|
# Gradio interface: credentials and model on the left, LoRA settings on the
# right, training hyper-parameters and the dataset list below.
with gr.Blocks(title="Entrenador LLM Ultimate") as demo:
    # Inject the custom stylesheet and the page header.
    gr.HTML(f"<style>{custom_css}</style>")
    gr.HTML("""
<div style="text-align: center; margin-bottom: 20px;">
<h1 style="margin: 0;">⚡ INFINITE LLM TRAINER ⚡</h1>
<p style="color: #a0aec0;">Entrenamiento Multi-Dataset con Fusión Automática y Subida a Hub</p>
</div>
""")

    with gr.Row():
        with gr.Column(scale=1):
            # The token box is intentionally NOT prefilled from the HF_TOKEN
            # environment variable: a component's initial value is sent to
            # every browser that opens the page, which would expose the
            # server's secret to anyone holding the (public) link.
            hf_token_input = gr.Textbox(label="HuggingFace Token", type="password", placeholder="hf_...")
            model_input = gr.Textbox(label="Modelo Base", value="", placeholder="Ej: Qwen/Qwen2.5-0.5B (Requerido)")
            repo_input = gr.Textbox(label="Nombre Nuevo Repo", value="multi-dataset-model-v1")

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 🎛️ Configuración Avanzada LoRA")
                r_input = gr.Slider(minimum=8, maximum=256, value=32, step=8, label="Rank (r)")
                alpha_input = gr.Slider(minimum=8, maximum=512, value=32, step=8, label="Alpha")
                dropout_input = gr.Slider(minimum=0.0, maximum=0.5, value=0.05, step=0.01, label="Dropout")

    with gr.Row():
        steps_input = gr.Number(label="Max Steps (Duración)", value=500, precision=0)
        lr_input = gr.Number(label="Learning Rate", value=2e-4)
        batch_input = gr.Number(label="Batch Size", value=1, precision=0)

    datasets_input = gr.Textbox(label="Fuentes de Datos (Datasets)", value="", placeholder="Pega aquí tus datasets separados por coma o salto de línea.\nEjemplo:\nSalesforce/fineweb_deduplicated\nbigcode/the-stack, v2", lines=12, elem_classes="input-box")

    train_btn = gr.Button("🚀 INICIAR ENTRENAMIENTO", elem_classes="primary-btn")
    status_output = gr.Textbox(label="Log del Sistema", interactive=False, lines=3)

    # The input order must match the positional parameters of run_training.
    train_btn.click(
        fn=run_training,
        inputs=[hf_token_input, model_input, repo_input, r_input, alpha_input, dropout_input,
                steps_input, lr_input, batch_input, datasets_input],
        outputs=status_output
    )
|
|
|
|
|
# Start the server only when the file is executed as a script (or in a
# notebook cell), so importing this module has no network side effects.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)