Ksjsjjdj commited on
Commit
9599c8e
·
verified ·
1 Parent(s): cfea477

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +271 -0
app.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import logging
4
+ import multiprocessing
5
+ import threading
6
+ from itertools import chain
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from datasets import load_dataset, get_dataset_config_names, IterableDataset
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback
10
+ from peft import LoraConfig, get_peft_model, PeftModel
11
+ from huggingface_hub import login, whoami, create_repo, upload_folder
12
+ from IPython.display import clear_output
13
+ import gradio as gr
14
+ from dotenv import load_dotenv
15
+ import spaces
16
+
17
+ load_dotenv()
18
+
19
+ class GradioProgressCallback(TrainerCallback):
20
+ def __init__(self, progress_bar):
21
+ self.progress_bar = progress_bar
22
+
23
+ def on_step_end(self, args, state, control, **kwargs):
24
+ if state.global_step > 0:
25
+ self.progress_bar(state.global_step / state.max_steps, desc=f"Paso {state.global_step}/{state.max_steps}")
26
+ return control
27
+
28
+ @spaces.GPU()
29
+ def run_training(hf_token, model_name, new_repo_name, lora_r, lora_alpha, lora_dropout,
30
+ train_steps, learning_rate, batch_size, datasets_text, progress=gr.Progress()):
31
+
32
+ os.environ["WANDB_DISABLED"] = "true"
33
+ os.environ["HF_TOKEN"] = hf_token
34
+
35
+ try:
36
+ login(token=hf_token)
37
+ username = whoami()["name"]
38
+ except Exception as e:
39
+ return f"Error de autenticación: {str(e)}"
40
+
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ num_workers = multiprocessing.cpu_count()
43
+
44
+ if not hasattr(torch, 'xla'):
45
+ class DummyXLA:
46
+ def __getattr__(self, name):
47
+ return lambda *args, **kwargs: None
48
+ torch.xla = DummyXLA()
49
+
50
+ logging.basicConfig(level=logging.INFO)
51
+ logger = logging.getLogger(__name__)
52
+
53
+ raw_items = datasets_text.replace('\n', ',').split(',')
54
+ dataset_list = [item.strip() for item in raw_items if item.strip()]
55
+
56
+ def get_sample_text(ds):
57
+ try:
58
+ sample = next(iter(ds))
59
+ if isinstance(sample, dict):
60
+ return sample.get("text", str(sample))
61
+ return str(sample)
62
+ except:
63
+ return None
64
+
65
+ def load_single(ds_name, cfg):
66
+ try:
67
+ ds = load_dataset(ds_name, cfg, streaming=True, trust_remote_code=True)
68
+ if isinstance(ds, dict):
69
+ ds = next(iter(ds.values()))
70
+
71
+ if get_sample_text(ds):
72
+ return ds
73
+ return None
74
+ except:
75
+ return None
76
+
77
+ def load_all_datasets():
78
+ streams = []
79
+ tasks = []
80
+ progress(0.1, desc="Analizando configuraciones de datasets...")
81
+
82
+ for ds_name in dataset_list:
83
+ try:
84
+ configs = get_dataset_config_names(ds_name)
85
+ except:
86
+ configs = []
87
+
88
+ if not configs:
89
+ tasks.append((ds_name, None))
90
+ else:
91
+ for c in configs:
92
+ tasks.append((ds_name, c))
93
+
94
+ progress(0.2, desc=f"Cargando {len(tasks)} fuentes de datos...")
95
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
96
+ future_to_task = {executor.submit(load_single, d, c): (d, c) for d, c in tasks}
97
+ for future in as_completed(future_to_task):
98
+ try:
99
+ ds = future.result()
100
+ if ds:
101
+ streams.append(ds)
102
+ except:
103
+ pass
104
+ return streams
105
+
106
+ loaded_streams = load_all_datasets()
107
+ if not loaded_streams:
108
+ return "Error Crítico: No se pudo cargar ningún dataset válido."
109
+
110
+ def all_samples():
111
+ return chain.from_iterable(loaded_streams)
112
+
113
+ progress(0.3, desc="Cargando Tokenizer...")
114
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left", add_eos_token=True, add_bos_token=True)
115
+ tokenizer.pad_token = tokenizer.eos_token
116
+
117
+ def create_text_lines(sample):
118
+ if isinstance(sample, dict):
119
+ text = sample.get("text", "\n".join(str(v) for v in sample.values() if isinstance(v, str)))
120
+ else:
121
+ text = str(sample)
122
+ return [line.strip() for line in text.splitlines() if line.strip()]
123
+
124
+ def process_sample(sample):
125
+ lines = create_text_lines(sample)
126
+ results = []
127
+ for line in lines:
128
+ tok = tokenizer(line, truncation=False)
129
+ tok["labels"] = tok["input_ids"].copy()
130
+ results.append(tok)
131
+ return results
132
+
133
+ def processed_samples_generator():
134
+ batch = []
135
+ for sample in all_samples():
136
+ batch.append(sample)
137
+ if len(batch) >= 100:
138
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
139
+ futures = [executor.submit(process_sample, s) for s in batch]
140
+ for future in as_completed(futures):
141
+ try:
142
+ res = future.result()
143
+ for tok in res:
144
+ yield tok
145
+ except:
146
+ pass
147
+ batch.clear()
148
+
149
+ if batch:
150
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
151
+ futures = [executor.submit(process_sample, s) for s in batch]
152
+ for future in as_completed(futures):
153
+ try:
154
+ res = future.result()
155
+ for tok in res:
156
+ yield tok
157
+ except:
158
+ pass
159
+
160
+ progress(0.4, desc="Cargando Modelo Base...")
161
+ original_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(device)
162
+
163
+ peft_config = LoraConfig(
164
+ r=int(lora_r),
165
+ lora_alpha=int(lora_alpha),
166
+ target_modules=["q_proj", "k_proj", "v_proj", "dense"],
167
+ bias="none",
168
+ lora_dropout=lora_dropout,
169
+ task_type="CAUSAL_LM"
170
+ )
171
+
172
+ peft_model = get_peft_model(original_model, peft_config)
173
+ peft_model.config.use_cache = False
174
+
175
+ output_dir = "/content/final-checkpoint"
176
+ max_steps_val = int(train_steps)
177
+ save_steps_val = max_steps_val // 2 if max_steps_val > 10 else 1
178
+
179
+ training_args = TrainingArguments(
180
+ output_dir=output_dir,
181
+ per_device_train_batch_size=int(batch_size),
182
+ gradient_accumulation_steps=1,
183
+ max_steps=max_steps_val,
184
+ learning_rate=learning_rate,
185
+ optim="adamw_torch",
186
+ logging_steps=5,
187
+ save_strategy="steps",
188
+ save_steps=save_steps_val,
189
+ report_to="none",
190
+ fp16=torch.cuda.is_available()
191
+ )
192
+
193
+ processed_dataset = IterableDataset.from_generator(processed_samples_generator)
194
+
195
+ trainer = Trainer(
196
+ model=peft_model,
197
+ train_dataset=processed_dataset,
198
+ args=training_args,
199
+ callbacks=[GradioProgressCallback(progress)]
200
+ )
201
+
202
+ progress(0.5, desc="Entrenando...")
203
+ trainer.train()
204
+
205
+ progress(0.8, desc="Guardando checkpoint...")
206
+ trainer.save_model(output_dir)
207
+
208
+ progress(0.9, desc="Fusionando modelo LoRA...")
209
+ ft = PeftModel.from_pretrained(original_model, output_dir, torch_dtype=torch.float32, is_trainable=False, device_map={"": device}).merge_and_unload()
210
+
211
+ final_path = "/content/merged_model"
212
+ ft.save_pretrained(final_path, safe_serialization=True)
213
+ tokenizer.save_pretrained(final_path)
214
+
215
+ progress(0.95, desc="Subiendo a HuggingFace...")
216
+ full_repo = f"{username}/{new_repo_name}"
217
+ create_repo(full_repo, token=hf_token, exist_ok=True)
218
+ upload_folder(folder_path=final_path, repo_id=full_repo, token=hf_token)
219
+
220
+ return f"¡Éxito! Modelo disponible en: https://huggingface.co/{full_repo}"
221
+
222
+ custom_css = """
223
+ body {background-color: #0b0f19; color: #e0e6ed;}
224
+ .gradio-container {max-width: 1200px !important; margin: 0 auto;}
225
+ h1 {text-align: center; color: #00e5ff; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; text-transform: uppercase; letter-spacing: 2px;}
226
+ .primary-btn {background: linear-gradient(135deg, #00C9FF 0%, #92FE9D 100%); border: none; color: #000; font-weight: 800; font-size: 16px; padding: 12px; transition: transform 0.2s;}
227
+ .primary-btn:hover {transform: scale(1.02); filter: brightness(1.1);}
228
+ .input-box textarea {font-family: 'Consolas', 'Monaco', monospace; font-size: 13px; background-color: #1a202c; color: #a0aec0; border: 1px solid #2d3748;}
229
+ .gr-box {border-radius: 8px; background-color: #1a202c; border: 1px solid #2d3748;}
230
+ label {color: #00e5ff !important; font-weight: bold;}
231
+ """
232
+
233
+ with gr.Blocks(css=custom_css, title="Entrenador LLM Ultimate") as demo:
234
+ gr.HTML("""
235
+ <div style="text-align: center; margin-bottom: 20px;">
236
+ <h1 style="margin: 0;">⚡ INFINITE LLM TRAINER ⚡</h1>
237
+ <p style="color: #a0aec0;">Entrenamiento Multi-Dataset con Fusión Automática y Subida a Hub</p>
238
+ </div>
239
+ """)
240
+
241
+ with gr.Row():
242
+ with gr.Column(scale=1):
243
+ hf_token_input = gr.Textbox(label="HuggingFace Token (Write)", type="password", placeholder="hf_...", value=os.getenv("HF_TOKEN", ""))
244
+ model_input = gr.Textbox(label="Modelo Base", value="arnir0/Tiny-LLM")
245
+ repo_input = gr.Textbox(label="Nombre Nuevo Repo", value="multi-dataset-model-v1")
246
+
247
+ with gr.Column(scale=1):
248
+ with gr.Group():
249
+ gr.Markdown("### 🎛️ Configuración Avanzada LoRA")
250
+ r_input = gr.Slider(minimum=8, maximum=256, value=32, step=8, label="Rank (r)")
251
+ alpha_input = gr.Slider(minimum=8, maximum=512, value=32, step=8, label="Alpha")
252
+ dropout_input = gr.Slider(minimum=0.0, maximum=0.5, value=0.05, step=0.01, label="Dropout")
253
+
254
+ with gr.Row():
255
+ steps_input = gr.Number(label="Max Steps (Duración)", value=500, precision=0)
256
+ lr_input = gr.Number(label="Learning Rate", value=2e-4)
257
+ batch_input = gr.Number(label="Batch Size", value=1, precision=0)
258
+
259
+ datasets_input = gr.Textbox(label="Fuentes de Datos (Datasets)", value="", placeholder="Pega aquí tus datasets separados por coma o salto de línea.\nEjemplo:\nSalesforce/fineweb_deduplicated\nbigcode/the-stack, v2", lines=12, elem_classes="input-box")
260
+
261
+ train_btn = gr.Button("🚀 INICIAR ENTRENAMIENTO GLOBAL", elem_classes="primary-btn")
262
+ status_output = gr.Textbox(label="Log del Sistema", interactive=False, lines=3)
263
+
264
+ train_btn.click(
265
+ fn=run_training,
266
+ inputs=[hf_token_input, model_input, repo_input, r_input, alpha_input, dropout_input,
267
+ steps_input, lr_input, batch_input, datasets_input],
268
+ outputs=status_output
269
+ )
270
+
271
+ demo.launch(share=True, debug=True)