Andro0s commited on
Commit
f061838
·
verified ·
1 Parent(s): 766e9ac

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from huggingface_hub import login
4
+ from datasets import load_dataset, Dataset
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
6
+ from peft import get_peft_model, LoraConfig, TaskType, PeftModel
7
+
8
+ # ============================================================
9
+ # ⚙️ CONFIGURACIÓN GLOBAL
10
+ # ============================================================
11
+ # Modelo base para generación de código
12
+ BASE_MODEL = "bigcode/santacoder"
13
+ LORA_PATH = "./lora_output" # Directorio para guardar los adaptadores LoRA
14
+ DATASET_PATH = "tu_dataset.json" # ¡Cambia esto por el nombre de tu archivo JSON!
15
+ MAX_TOKEN_LENGTH = 256 # Longitud de secuencia uniforme (corrige errores de tamaño)
16
+
17
+ # Variables globales para acceso a recursos
18
+ tokenizer = None
19
+ lora_model = None
20
+ tokenized_dataset = None
21
+ lora_generator = None
22
+
23
+ # ============================================================
24
+ # 🔐 AUTENTICACIÓN Y PRE-CARGA DE RECURSOS (SINGLETON)
25
+ # ============================================================
26
+
27
+ def setup_resources():
28
+ """Carga y configura todos los recursos (modelo, tokenizer, dataset) una sola vez."""
29
+ global tokenizer, lora_model, tokenized_dataset
30
+
31
+ # 1. Autenticación con Hugging Face
32
+ hf_token = os.environ.get("HF_TOKEN")
33
+ if hf_token:
34
+ login(token=hf_token)
35
+ else:
36
+ print("⚠️ Token no encontrado. La app intentará correr sin autenticación de escritura.")
37
+
38
+ # 2. Carga del Tokenizer y Modelo Base
39
+ print("\n🔄 Cargando modelo y tokenizer una sola vez...")
40
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
41
+ base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
42
+
43
+ if tokenizer.pad_token is None:
44
+ tokenizer.pad_token = tokenizer.eos_token
45
+
46
+ # 3. Configuración y Aplicación LoRA (PEFT)
47
+ peft_config = LoraConfig(
48
+ task_type=TaskType.CAUSAL_LM,
49
+ r=8,
50
+ lora_alpha=32,
51
+ lora_dropout=0.1,
52
+ target_modules=["c_proj", "c_attn"], # Adaptado para modelos causales tipo GPT/Santacoder
53
+ )
54
+ # Envuelve el modelo base con la configuración LoRA
55
+ lora_model = get_peft_model(base_model, peft_config)
56
+ print(f"✅ Modelo LoRA preparado. Parámetros entrenables: {lora_model.print_trainable_parameters()}")
57
+
58
+ # 4. Carga y Tokenización del Dataset (para eficiencia)
59
+ print("📚 Cargando y tokenizando dataset (esto solo se hace una vez)...")
60
+ try:
61
+ raw_dataset = load_dataset("json", data_files=DATASET_PATH)
62
+ # Tokenización rápida con batched=True
63
+ tokenized_dataset = raw_dataset.map(
64
+ lambda e: tokenizer(
65
+ e["prompt"] + e["completion"],
66
+ truncation=True,
67
+ padding="max_length",
68
+ max_length=MAX_TOKEN_LENGTH
69
+ ),
70
+ batched=True,
71
+ remove_columns=raw_dataset["train"].column_names
72
+ )
73
+ print("✅ Dataset tokenizado correctamente.")
74
+ except Exception as e:
75
+ tokenized_dataset = None
76
+ print(f"❌ Error al cargar o tokenizar el dataset. Asegúrate que '{DATASET_PATH}' exista. {e}")
77
+
78
+
79
+ # ============================================================
80
+ # 🧩 FUNCIÓN DE ENTRENAMIENTO
81
+ # ============================================================
82
+ def train_lora(epochs, batch_size, learning_rate):
83
+ """Ejecuta el entrenamiento del modelo LoRA."""
84
+ global lora_model, tokenized_dataset, lora_generator
85
+
86
+ # Verifica si el dataset está disponible
87
+ if tokenized_dataset is None or "train" not in tokenized_dataset:
88
+ return "❌ Error: El dataset no pudo cargarse o está vacío. No se puede entrenar."
89
+
90
+ try:
91
+ # Reinicia el generador para que cargue el modelo entrenado en el siguiente test
92
+ lora_generator = None
93
+
94
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
95
+
96
+ training_args = TrainingArguments(
97
+ output_dir=LORA_PATH,
98
+ per_device_train_batch_size=int(batch_size),
99
+ num_train_epochs=float(epochs),
100
+ learning_rate=float(learning_rate),
101
+ save_total_limit=1,
102
+ logging_steps=10,
103
+ push_to_hub=False,
104
+ )
105
+
106
+ trainer = Trainer(
107
+ model=lora_model, # Usa el modelo LoRA global
108
+ args=training_args,
109
+ train_dataset=tokenized_dataset["train"],
110
+ data_collator=data_collator,
111
+ )
112
+
113
+ trainer.train()
114
+
115
+ # Guarda SOLAMENTE los adaptadores LoRA (archivos pequeños)
116
+ lora_model.save_pretrained(LORA_PATH)
117
+ tokenizer.save_pretrained(LORA_PATH)
118
+
119
+ return f"✅ Entrenamiento completado. Adaptadores LoRA guardados en **{LORA_PATH}**"
120
+ except Exception as e:
121
+ return f"❌ Error durante el entrenamiento: {e}"
122
+
123
+ # ============================================================
124
+ # 🤖 FUNCI��N DE GENERACIÓN (INFERENCIA)
125
+ # ============================================================
126
+ def generate_text(prompt_text):
127
+ """Genera texto usando el modelo base + adaptadores LoRA."""
128
+ global lora_generator
129
+
130
+ try:
131
+ # Carga el generador SOLAMENTE si no ha sido cargado aún
132
+ if lora_generator is None:
133
+ # 1. Cargar el modelo base
134
+ base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
135
+
136
+ # 2. Aplicar adaptadores LoRA si existen
137
+ if os.path.exists(LORA_PATH):
138
+ model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
139
+ else:
140
+ # Si no hay LoRA, usa el modelo base (solo si está cargado globalmente)
141
+ model_with_lora = base_model_gen
142
+
143
+ # 3. Fusionar el modelo base y los adaptadores para una inferencia más rápida
144
+ final_model = model_with_lora.merge_and_unload()
145
+ final_model.eval() # Poner en modo evaluación
146
+
147
+ # 4. Inicializar el pipeline global
148
+ lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer)
149
+
150
+ # Generar la respuesta
151
+ output = lora_generator(prompt_text, max_new_tokens=150, temperature=0.7, top_p=0.9)
152
+ return output[0]["generated_text"]
153
+
154
+ except Exception as e:
155
+ return f"❌ Error generando texto (Asegúrate de que el modelo base/LoRA esté cargado): {e}"
156
+
157
+ # ============================================================
158
+ # 💻 INTERFAZ GRADIO
159
+ # ============================================================
160
+ with gr.Blocks(title="AmorCoderAI - LoRA") as demo:
161
+ gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas LoRA")
162
+ gr.Markdown(f"Modelo base: `{BASE_MODEL}`. Adaptadores guardados en `{LORA_PATH}`. **El auto-entrenamiento se ejecuta al iniciar.**")
163
+
164
+ with gr.Tab("🧠 Entrenar (Manual)"):
165
+ gr.Markdown("--- **¡CUIDADO!** El entrenamiento es lento y consume muchos recursos (VRAM/RAM). ---")
166
+ epochs = gr.Number(value=1, label="Épocas", precision=0)
167
+ batch_size = gr.Number(value=2, label="Tamaño de lote (ajusta según tu VRAM)", precision=0)
168
+ learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje")
169
+ train_button = gr.Button("🚀 Iniciar Entrenamiento Manual")
170
+ train_output = gr.Textbox(label="Resultado del Entrenamiento Manual")
171
+
172
+ train_button.click(
173
+ train_lora,
174
+ inputs=[epochs, batch_size, learning_rate],
175
+ outputs=train_output
176
+ )
177
+
178
+ with gr.Tab("✨ Probar modelo"):
179
+ prompt = gr.Textbox(label="Escribe código (ej: 'def fibonacci(n):')", lines=4)
180
+ generate_button = gr.Button("💬 Generar código")
181
+ output_box = gr.Textbox(label="Salida generada", lines=10)
182
+ generate_button.click(generate_text, inputs=prompt, outputs=output_box)
183
+
184
+ # ============================================================
185
+ # 🚀 LANZAR APP Y AUTO-ENTRENAMIENTO (¡AQUÍ SUCEDE LA MAGIA!)
186
+ # ============================================================
187
+ if __name__ == "__main__":
188
+ # 1. Cargar todos los recursos globales
189
+ setup_resources()
190
+
191
+ # 2. AUTO-ENTRENAMIENTO (Se ejecuta con valores por defecto)
192
+ print("\n=============================================")
193
+ print("🤖 INICIANDO AUTO-ENTRENAMIENTO (1 Época, 2 Batch Size)")
194
+ print("=============================================")
195
+
196
+ # Parámetros por defecto para el auto-entrenamiento
197
+ auto_train_result = train_lora(epochs=1, batch_size=2, learning_rate=5e-5)
198
+
199
+ print(f"\nFIN DEL AUTO-ENTRENAMIENTO: {auto_train_result}")
200
+
201
+ # 3. Lanzar la Interfaz Gradio
202
+ print("\n=============================================")
203
+ print("💻 LANZANDO INTERFAZ GRADIO")
204
+ print("=============================================")
205
+ demo.launch()