# app.py — LLM fine-tuning Gradio Space (update 865fcd1, verified)
import os
#os.system("pip install spaces-0.1.0-py3-none-any.whl")
import torch
import logging
import multiprocessing
import threading
from itertools import chain
from concurrent.futures import ThreadPoolExecutor, as_completed
from datasets import load_dataset, get_dataset_config_names, IterableDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback
from peft import LoraConfig, get_peft_model, PeftModel
from huggingface_hub import login, whoami, create_repo, upload_folder
from IPython.display import clear_output
import gradio as gr
from dotenv import load_dotenv
import spaces
# Load environment variables from a local .env file when present.
# Best-effort: a missing/malformed .env must not stop the app, but a bare
# `except:` would also swallow SystemExit/KeyboardInterrupt — narrow it.
try:
    load_dotenv()
except Exception:
    pass
# NOTE(review): the original decorated this class with @spaces.GPU. That
# decorator is meant for *functions* that perform GPU work on ZeroGPU Spaces;
# wrapping a TrainerCallback class breaks plain instantiation by the Trainer,
# so the decorator is removed here (run_training keeps its own @spaces.GPU()).
class GradioProgressCallback(TrainerCallback):
    """Bridge HF Trainer step events to a Gradio progress bar."""

    def __init__(self, progress_bar):
        # progress_bar: a gr.Progress-like callable — progress_bar(fraction, desc=...)
        self.progress_bar = progress_bar

    def on_step_end(self, args, state, control, **kwargs):
        """Update the bar after every optimizer step (skip step 0)."""
        if state.global_step > 0:
            self.progress_bar(state.global_step / state.max_steps, desc=f"Paso {state.global_step}/{state.max_steps}")
        return control
@spaces.GPU()
def run_training(hf_token, model_name, new_repo_name, lora_r, lora_alpha, lora_dropout,
                 train_steps, learning_rate, batch_size, datasets_text, progress=gr.Progress()):
    """Fine-tune *model_name* with LoRA on the datasets listed in *datasets_text*
    (comma- and/or newline-separated), merge the adapter into the base model,
    and upload the merged model to the HF Hub under the caller's account.

    Returns a human-readable status string: a success URL, or an error message.
    """
    os.environ["WANDB_DISABLED"] = "true"  # avoid interactive wandb prompt
    os.environ["HF_TOKEN"] = hf_token

    # Authenticate first so a bad token fails fast.
    try:
        login(token=hf_token)
        username = whoami()["name"]
    except Exception as e:
        return f"Error de autenticación: {str(e)}"

    num_workers = multiprocessing.cpu_count()

    # Some Trainer code paths probe torch.xla; stub it out when absent so a
    # CUDA/CPU-only install does not crash on the attribute lookup.
    if not hasattr(torch, 'xla'):
        class DummyXLA:
            def __getattr__(self, name):
                return lambda *args, **kwargs: None
        torch.xla = DummyXLA()

    logging.basicConfig(level=logging.INFO)

    # Accept sources separated by commas and/or newlines.
    raw_items = datasets_text.replace('\n', ',').split(',')
    dataset_list = [item.strip() for item in raw_items if item.strip()]

    def get_sample_text(ds):
        """Return one text sample from the stream, or None if it yields nothing usable."""
        try:
            sample = next(iter(ds))
            if isinstance(sample, dict):
                return sample.get("text", str(sample))
            return str(sample)
        except Exception:
            return None

    def load_single(ds_name, cfg):
        """Open one (dataset, config) pair as a stream; None when it fails or is empty."""
        try:
            ds = load_dataset(ds_name, cfg, streaming=True)
            if isinstance(ds, dict):
                ds = next(iter(ds.values()))  # take the first available split
            return ds if get_sample_text(ds) else None
        except Exception:
            return None

    def load_all_datasets():
        """Resolve every config of every requested dataset concurrently; best-effort."""
        streams = []
        tasks = []
        progress(0.1, desc="Analizando configuraciones...")
        for ds_name in dataset_list:
            try:
                configs = get_dataset_config_names(ds_name)
            except Exception:
                configs = []
            if not configs:
                tasks.append((ds_name, None))
            else:
                tasks.extend((ds_name, c) for c in configs)
        progress(0.2, desc=f"Cargando {len(tasks)} fuentes...")
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_task = {executor.submit(load_single, d, c): (d, c) for d, c in tasks}
            for future in as_completed(future_to_task):
                try:
                    ds = future.result()
                    if ds:
                        streams.append(ds)
                except Exception:
                    pass  # deliberately best-effort: skip sources that error out
        return streams

    loaded_streams = load_all_datasets()
    if not loaded_streams:
        return "Error: No se pudo cargar ningún dataset válido."

    def all_samples():
        # Lazily concatenate all streaming datasets into one iterator.
        return chain.from_iterable(loaded_streams)

    progress(0.3, desc="Cargando Tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", add_eos_token=True, add_bos_token=True)
        # Causal LMs often ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        return f"Error cargando tokenizer: {str(e)}"

    def create_text_lines(sample):
        """Extract non-empty, stripped text lines from one raw sample."""
        if isinstance(sample, dict):
            # Prefer a "text" field; otherwise join all string-valued fields.
            text = sample.get("text", "\n".join(str(v) for v in sample.values() if isinstance(v, str)))
        else:
            text = str(sample)
        return [line.strip() for line in text.splitlines() if line.strip()]

    def process_sample(sample):
        """Tokenize each line of a sample; labels mirror input_ids for causal LM."""
        results = []
        for line in create_text_lines(sample):
            tok = tokenizer(line, truncation=False)
            tok["labels"] = tok["input_ids"].copy()
            results.append(tok)
        return results

    def _flush_batch(batch):
        """Tokenize a batch of raw samples in parallel, yielding tokenized lines.

        Extracted to remove the duplicated executor block the original had for
        the in-loop flush and the final partial-batch flush.
        """
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(process_sample, s) for s in batch]
            for future in as_completed(futures):
                try:
                    yield from future.result()
                except Exception:
                    pass  # skip samples that fail to tokenize

    def processed_samples_generator():
        """Stream tokenized examples, processing raw samples in batches of 100."""
        batch = []
        for sample in all_samples():
            batch.append(sample)
            if len(batch) >= 100:
                yield from _flush_batch(batch)
                batch.clear()
        if batch:
            yield from _flush_batch(batch)

    progress(0.4, desc="Cargando Modelo...")
    try:
        original_model = AutoModelForCausalLM.from_pretrained(model_name)
    except Exception as e:
        return f"Error cargando modelo: {str(e)}"

    peft_config = LoraConfig(
        r=int(lora_r),
        lora_alpha=int(lora_alpha),
        target_modules=["q_proj", "k_proj", "v_proj", "dense"],
        bias="none",
        lora_dropout=lora_dropout,
        task_type="CAUSAL_LM"
    )
    peft_model = get_peft_model(original_model, peft_config)
    peft_model.config.use_cache = False  # caching is incompatible with training

    output_dir = "/content/final-checkpoint"
    max_steps_val = int(train_steps)
    # Checkpoint at the halfway mark for long runs; every step for tiny runs.
    save_steps_val = max_steps_val // 2 if max_steps_val > 10 else 1
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=int(batch_size),
        gradient_accumulation_steps=1,
        max_steps=max_steps_val,
        learning_rate=float(learning_rate),  # gr.Number may hand back an int
        optim="adamw_torch",
        logging_steps=5,
        save_strategy="steps",
        save_steps=save_steps_val,
        report_to="none"
    )
    processed_dataset = IterableDataset.from_generator(processed_samples_generator)
    trainer = Trainer(
        model=peft_model,
        train_dataset=processed_dataset,
        args=training_args,
        callbacks=[GradioProgressCallback(progress)]
    )

    progress(0.5, desc="Entrenando...")
    trainer.train()
    progress(0.8, desc="Guardando...")
    trainer.save_model(output_dir)

    # Merge the LoRA adapter back into the base weights for a standalone model.
    progress(0.9, desc="Fusionando...")
    ft = PeftModel.from_pretrained(original_model, output_dir, torch_dtype=torch.float32, is_trainable=False).merge_and_unload()
    final_path = "/content/merged_model"
    ft.save_pretrained(final_path, safe_serialization=True)
    tokenizer.save_pretrained(final_path)

    progress(0.95, desc="Subiendo...")
    full_repo = f"{username}/{new_repo_name}"
    create_repo(full_repo, token=hf_token, exist_ok=True)
    upload_folder(folder_path=final_path, repo_id=full_repo, token=hf_token)
    return f"Completado: https://huggingface.co/{full_repo}"
# Dark-theme CSS injected into the Blocks UI via an inline <style> tag below.
custom_css = """
body {background-color: #0b0f19; color: #e0e6ed;}
.gradio-container {max-width: 1200px !important; margin: 0 auto;}
h1 {text-align: center; color: #00e5ff; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; text-transform: uppercase; letter-spacing: 2px;}
.primary-btn {background: linear-gradient(135deg, #00C9FF 0%, #92FE9D 100%); border: none; color: #000; font-weight: 800; font-size: 16px; padding: 12px; transition: transform 0.2s;}
.primary-btn:hover {transform: scale(1.02); filter: brightness(1.1);}
.input-box textarea {font-family: 'Consolas', 'Monaco', monospace; font-size: 13px; background-color: #1a202c; color: #a0aec0; border: 1px solid #2d3748;}
.gr-box {border-radius: 8px; background-color: #1a202c; border: 1px solid #2d3748;}
label {color: #00e5ff !important; font-weight: bold;}
"""
# Build and launch the Gradio UI; widget values feed run_training directly.
with gr.Blocks(title="Entrenador LLM Ultimate") as app:
    gr.HTML(f"<style>{custom_css}</style>")
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1 style="margin: 0;">⚡ INFINITE LLM TRAINER ⚡</h1>
        <p style="color: #a0aec0;">Entrenamiento Multi-Dataset con Fusión Automática y Subida a Hub</p>
    </div>
    """)

    with gr.Row():
        # Left column: credentials and repository identifiers.
        with gr.Column(scale=1):
            token_box = gr.Textbox(label="HuggingFace Token", type="password", placeholder="hf_...", value=os.getenv("HF_TOKEN", ""))
            base_model_box = gr.Textbox(label="Modelo Base", value="", placeholder="Ej: Qwen/Qwen2.5-0.5B (Requerido)")
            repo_box = gr.Textbox(label="Nombre Nuevo Repo", value="multi-dataset-model-v1")
        # Right column: LoRA hyperparameters.
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 🎛️ Configuración Avanzada LoRA")
                rank_slider = gr.Slider(minimum=8, maximum=256, value=32, step=8, label="Rank (r)")
                alpha_slider = gr.Slider(minimum=8, maximum=512, value=32, step=8, label="Alpha")
                dropout_slider = gr.Slider(minimum=0.0, maximum=0.5, value=0.05, step=0.01, label="Dropout")

    # Training-loop knobs.
    with gr.Row():
        steps_box = gr.Number(label="Max Steps (Duración)", value=500, precision=0)
        lr_box = gr.Number(label="Learning Rate", value=2e-4)
        batch_box = gr.Number(label="Batch Size", value=1, precision=0)

    sources_box = gr.Textbox(label="Fuentes de Datos (Datasets)", value="", placeholder="Pega aquí tus datasets separados por coma o salto de línea.\nEjemplo:\nSalesforce/fineweb_deduplicated\nbigcode/the-stack, v2", lines=12, elem_classes="input-box")
    launch_btn = gr.Button("🚀 INICIAR ENTRENAMIENTO", elem_classes="primary-btn")
    status_box = gr.Textbox(label="Log del Sistema", interactive=False, lines=3)

    # Wire the button to the training entry point; output goes to the log box.
    launch_btn.click(
        fn=run_training,
        inputs=[token_box, base_model_box, repo_box, rank_slider, alpha_slider, dropout_slider,
                steps_box, lr_box, batch_box, sources_box],
        outputs=status_box
    )

app.launch(share=True, debug=True)