Andro0s commited on
Commit
dbf0a6c
·
verified ·
1 Parent(s): 231e120

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +316 -0
app.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import threading
4
+ import time
5
+ from huggingface_hub import login
6
+ from datasets import load_dataset
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline, AutoModelForSeq2SeqLM
8
+ from peft import get_peft_model, LoraConfig, TaskType, PeftModel
9
+ import json
10
+
11
+ # --- CONFIGURACIÓN DEL MODELO Y ENTRENAMIENTO ---
12
+ BASE_MODEL = "bigcode/santacoder" # Modelo base de programación
13
+ LORA_PATH = "./lora_output" # Ruta donde se guarda el modelo adaptado
14
+ DATASET_FILE = "codesearchnet_lora_dataset.json"
15
+ MAX_TOKEN_LENGTH = 256
16
+ NUM_SAMPLES_TO_PROCESS = 1000
17
+ DEFAULT_EPOCHS = 10
18
+
19
+ # Configuración del ciclo AUTÓNOMO (Inicia reentrenamiento cada 5 interacciones)
20
+ GENERATION_LIMIT_TO_TRAIN = 5
21
+ AUTONOMOUS_EPOCHS = 3
22
+
23
+ # Modelo de chat pre-entrenado en español
24
+ CHAT_MODEL_NAME = "bigscience/bloom"
25
+ chat_tokenizer = None
26
+ chat_model = None
27
+
28
+ # --- ESTADO GLOBAL Y THREADING ---
29
+ tokenizer = None
30
+ lora_model = None
31
+ tokenized_dataset = None
32
+ lora_generator = None
33
+
34
+ # Variables de estado
35
+ version_number = 1.0
36
+ is_trained = os.path.exists(LORA_PATH)
37
+ generations_since_last_train = 0
38
+ training_status_message = "Esperando la inicialización V1.0..."
39
+
40
+ # Lock para proteger las variables compartidas entre hilos (CRÍTICO para estabilidad)
41
+ global_lock = threading.Lock()
42
+
43
+ # --- LÓGICA DE PREPARACIÓN Y SETUP ---
44
+
45
+ def prepare_codesearchnet():
46
+ """Descarga y prepara el dataset inicial si no existe."""
47
+ if os.path.exists(DATASET_FILE):
48
+ return
49
+ try:
50
+ raw_csn = load_dataset('Nan-Do/code-search-net-python', split=f'train[:{NUM_SAMPLES_TO_PROCESS}]')
51
+
52
+ def format_for_lora(example):
53
+ # Formato que entrena a la IA a enlazar descripción (español) con código (inglés)
54
+ prompt_text = (
55
+ f"# Descripción: {example['docstring_summary']}\n"
56
+ f"# Completa la siguiente función:\n"
57
+ f"def {example['func_name']}("
58
+ )
59
+ completion_text = example['code']
60
+ return {"prompt": prompt_text, "completion": completion_text}
61
+
62
+ lora_dataset = raw_csn.map(format_for_lora, batched=False, remove_columns=raw_csn["train"].column_names)
63
+ lora_dataset.to_json(DATASET_FILE)
64
+ except Exception as e:
65
+ print(f"Error al cargar dataset. Usando datos mínimos. Error: {e}")
66
+ minimal_dataset = [{"prompt": "# Error de carga. Intenta de nuevo.", "completion": "pass\n"}] * 10
67
+ with open(DATASET_FILE, 'w') as f:
68
+ json.dump(minimal_dataset, f)
69
+
70
+ def setup_resources():
71
+ """Configura el tokenizer, el modelo base y el adaptador LoRA."""
72
+ global tokenizer, lora_model, tokenized_dataset, chat_tokenizer, chat_model
73
+
74
+ prepare_codesearchnet()
75
+
76
+ hf_token = os.environ.get("HF_TOKEN")
77
+ if hf_token:
78
+ login(token=hf_token)
79
+
80
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
81
+ base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
82
+
83
+ if tokenizer.pad_token is None:
84
+ tokenizer.pad_token = tokenizer.eos_token
85
+
86
+ peft_config = LoraConfig(
87
+ task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["c_proj", "c_attn"],
88
+ )
89
+ lora_model = get_peft_model(base_model, peft_config)
90
+
91
+ try:
92
+ raw_dataset = load_dataset("json", data_files=DATASET_FILE)
93
+
94
+ def tokenize_function(examples):
95
+ return tokenizer(
96
+ examples["prompt"] + examples["completion"],
97
+ truncation=True,
98
+ padding="max_length",
99
+ max_length=MAX_TOKEN_LENGTH
100
+ )
101
+
102
+ tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=raw_dataset["train"].column_names if "train" in raw_dataset else [],)
103
+ except Exception:
104
+ tokenized_dataset = None
105
+
106
+ # Configuración del modelo de chat
107
+ chat_tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)
108
+ chat_model = AutoModelForSeq2SeqLM.from_pretrained(CHAT_MODEL_NAME)
109
+
110
+ # --- FUNCIÓN DE ENTRENAMIENTO (EJECUTADA EN HILO SEPARADO) ---
111
+
112
+ def autonomous_train_lora(epochs, batch_size, learning_rate):
113
+ """Ejecuta el entrenamiento en un hilo separado para la autonomía."""
114
+ global lora_model, tokenized_dataset, lora_generator, version_number, is_trained, training_status_message
115
+
116
+ try:
117
+ with global_lock:
118
+ if tokenized_dataset is None or "train" not in tokenized_dataset:
119
+ training_status_message = "ERROR: No se puede entrenar. Dataset no disponible."
120
+ return
121
+
122
+ # 1. ACTUALIZAR VERSIÓN (Pre-incremento)
123
+ if is_trained:
124
+ version_number += 0.1
125
+ else:
126
+ version_number = 1.0
127
+
128
+ # 2. CONFIGURACIÓN E INICIO DEL ENTRENAMIENTO
129
+ training_status_message = f"🧠 ENTRENANDO V{version_number:.1f} (Epochs: {epochs})...."
130
+ print(f"\n[AUTÓNOMO] {training_status_message}")
131
+
132
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
133
+ training_args = TrainingArguments(
134
+ output_dir=LORA_PATH,
135
+ per_device_train_batch_size=int(batch_size),
136
+ num_train_epochs=float(epochs),
137
+ learning_rate=float(learning_rate),
138
+ save_total_limit=1,
139
+ logging_steps=10,
140
+ push_to_hub=False,
141
+ disable_tqdm=True,
142
+ report_to="none"
143
+ )
144
+
145
+ trainer = Trainer(model=lora_model, args=training_args, train_dataset=tokenized_dataset["train"], data_collator=data_collator)
146
+
147
+ trainer.train()
148
+ lora_model.save_pretrained(LORA_PATH)
149
+ tokenizer.save_pretrained(LORA_PATH)
150
+
151
+ # 3. Marcar como entrenado
152
+ is_trained = True
153
+ training_status_message = f"✅ ENTRENAMIENTO V{version_number:.1f} COMPLETADO. Modelo listo para Hot Swap."
154
+ print(f"[AUTÓNOMO] {training_status_message}")
155
+
156
+ except Exception as e:
157
+ training_status_message = f"ERROR CRÍTICO durante el entrenamiento autónomo: {e}"
158
+ print(f"[AUTÓNOMO] {training_status_message}")
159
+
160
+ # --- FUNCIÓN DE GENERACIÓN (CORREGIDA PARA RETORNAR 2 VALORES) ---
161
+
162
+ def generate_text(prompt_text):
163
+ """Genera código y dispara el ciclo de reentrenamiento autónomo si es necesario."""
164
+ global lora_generator, generations_since_last_train, is_trained, version_number, training_status_message
165
+
166
+ if not is_trained:
167
+ # Si el entrenamiento V1.0 no ha terminado, retorna el mensaje de error y el estado actual
168
+ return "ERROR: El modelo LoRA no ha sido entrenado. Por favor, espere mientras la IA se inicializa con el entrenamiento V1.0.", update_status()
169
+
170
+ # 1. HOT SWAP (Verifica si el modelo necesita recargarse con la nueva versión)
171
+ if lora_generator is None:
172
+ with global_lock:
173
+ try:
174
+ # Recarga el modelo solo si está vacío
175
+ base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
176
+ model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
177
+ final_model = model_with_lora.merge_and_unload()
178
+ final_model.eval()
179
+ lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer)
180
+ print(f"[HOT SWAP] 🔄 Modelo de inferencia V{version_number:.1f} recargado y listo.")
181
+ except Exception as e:
182
+ # Si la recarga falla, retorna un error
183
+ return f"Error al cargar el modelo V{version_number:.1f} para inferencia: {e}", update_status()
184
+
185
+ # 2. Generación de texto (Lógica de inferencia)
186
+ try:
187
+ # Prepara el prompt para guiar la generación del código
188
+ prompt_with_indent = prompt_text.strip() + "\n    "
189
+ output = lora_generator(prompt_with_indent, max_new_tokens=150, temperature=0.7, top_p=0.9, clean_up_tokenization_spaces=True)
190
+ full_output = output[0]["generated_text"]
191
+
192
+ # Extrae solo la parte de la compleción (el código generado)
193
+ start_index = full_output.find(prompt_with_indent)
194
+ completion = full_output[start_index + len(prompt_with_indent):] if start_index != -1 else full_output
195
+
196
+ # 3. Aumentar contador de autonomía
197
+ with global_lock:
198
+ generations_since_last_train += 1
199
+ current_count = generations_since_last_train
200
+ current_version = version_number
201
+
202
+ # 4. Verificar si se requiere reentrenamiento (y dispararlo en un nuevo hilo)
203
+ notification = ""
204
+ if current_count >= GENERATION_LIMIT_TO_TRAIN:
205
+ # Verifica que no haya otro hilo de entrenamiento ya corriendo
206
+ if not any(isinstance(t, threading.Thread) and t.name == 'AutonomousTrainer' for t in threading.enumerate()):
207
+ print(f"[AUTONOMÍA] Generación #{current_count} alcanzada. Disparando reentrenamiento autónomo en segundo plano...")
208
+
209
+ # Reiniciar el contador de generaciones y forzar Hot Swap en la próxima interacción
210
+ with global_lock:
211
+ generations_since_last_train = 0
212
+ lora_generator = None
213
+
214
+ trainer_thread = threading.Thread(
215
+ target=autonomous_train_lora,
216
+ args=(AUTONOMOUS_EPOCHS, 2, 5e-5),
217
+ name='AutonomousTrainer'
218
+ )
219
+ trainer_thread.daemon = True
220
+ trainer_thread.start()
221
+
222
+ notification = f"\n\n--- [AUTONOMÍA] La IA ha iniciado el reentrenamiento V{current_version+0.1:.1f} para mejorar la traducción de tu diálogo. La próxima generación cargará la nueva versión. ---"
223
+
224
+ # CORRECCIÓN CLAVE: Retorna el código Y el estado actualizado
225
+ return completion + notification, update_status()
226
+
227
+ except Exception as e:
228
+ # Si falla la generación, retorna el mensaje de error y el estado actual
229
+ return f"Error generando texto: {e}", update_status()
230
+
231
+ # --- FUNCIÓN PARA INICIALIZACIÓN Y ENTRENAMIENTO V1.0 (Obligatorio) ---
232
+
233
+ def initialize_and_train_v1():
234
+ """Ejecuta el entrenamiento inicial V1.0 de forma autónoma al iniciar."""
235
+ if not is_trained:
236
+ autonomous_train_lora(epochs=DEFAULT_EPOCHS, batch_size=2, learning_rate=5e-5)
237
+ else:
238
+ global training_status_message
239
+ training_status_message = f"✅ Modelo V{version_number:.1f} ya entrenado. Listo."
240
+ print(f"[INICIALIZACIÓN] {training_status_message}")
241
+
242
+ # --- FUNCIÓN PARA ACTUALIZAR EL ESTADO EN LA UI ---
243
+
244
+ def update_status():
245
+ """Actualiza la versión y el estado del entrenamiento en la interfaz de Gradio."""
246
+ global training_status_message, version_number
247
+ # Retorna un texto en Markdown que se actualiza constantemente
248
+ return f"**Versión de Comprensión:** V{version_number:.1f} | **Estado del Entrenador:** {training_status_message}"
249
+
250
+ # --- FUNCIÓN DE CHAT ---
251
+ def chat_response(user_input):
252
+ """Genera una respuesta de chat basado en el modelo pre-entrenado."""
253
+ inputs = chat_tokenizer(user_input, return_tensors="pt")
254
+ outputs = chat_model.generate(**inputs)
255
+ response = chat_tokenizer.decode(outputs[0], skip_special_tokens=True)
256
+ return response
257
+
258
+ # --- INTERFAZ GRADIO ---
259
+ with gr.Blocks(title="AmorCoderAI - Aprendizaje Continuo") as demo:
260
+ gr.Markdown("# 💙 AmorCoderAI - Asistente de Código con Aprendizaje Continuo")
261
+
262
+ # Muestra la versión y el estado.
263
+ version_and_status = gr.Markdown(
264
+ f"**Versión de Comprensión:** V{version_number:.1f} | **Estado del Entrenador:** {training_status_message}",
265
+ elem_id="status_display"
266
+ )
267
+
268
+ gr.Markdown(f"**Modo Autónomo:** La IA se reentrena automáticamente cada **{GENERATION_LIMIT_TO_TRAIN}** códigos generados. Esto mejora su capacidad para traducir tu español conversacional a código.")
269
+
270
+ with gr.Tab("✨ Generación de Código"):
271
+ gr.Markdown("## Escribe tu idea en palabras (¡Usa español fluido!)")
272
+
273
+ gr.Markdown("Recomendación inicial: Usa el siguiente formato para obtener el mejor código mientras la IA aprende tu idioma:")
274
+
275
+ prompt = gr.Textbox(
276
+ label="Instrucción de Programación:",
277
+ lines=4,
278
+ placeholder="# Descripción: Quiero que me hagas un código similar a Google Gemini.\n# Completa la siguiente función:\ndef generar_contenido(prompt, modelo):"
279
+ )
280
+ generate_button = gr.Button("💬 Generar código y disparar Aprendizaje")
281
+ output_box = gr.Textbox(label="Código generado", lines=10)
282
+
283
+ # Conexión del botón con la función principal
284
+ # IMPORTANTE: Ahora generate_text retorna DOS valores para coincidir con [output_box, version_and_status]
285
+ generate_button.click(
286
+ generate_text,
287
+ inputs=prompt,
288
+ outputs=[output_box, version_and_status],
289
+ )
290
+
291
+ with gr.Tab("🗣️ Chat en Español"):
292
+ gr.Markdown("## Habla con la IA en español")
293
+ user_input = gr.Textbox(label="Tu mensaje:", lines=2)
294
+ chat_button = gr.Button("Enviar")
295
+ chat_output = gr.Textbox(label="Respuesta de la IA", lines=5)
296
+
297
+ chat_button.click(
298
+ chat_response,
299
+ inputs=user_input,
300
+ outputs=chat_output
301
+ )
302
+
303
+ # El estado se actualiza solo al cargar la página.
304
+ demo.load(update_status, None, version_and_status)
305
+
306
+ # --- INICIO DE LA APLICACIÓN ---
307
+ if __name__ == "__main__":
308
+ setup_resources()
309
+
310
+ # Lanza el entrenamiento V1.0 inicial en un hilo para que no congele la UI
311
+ initialization_thread = threading.Thread(target=initialize_and_train_v1, name='InitializationTrainer')
312
+ initialization_thread.daemon = True
313
+ initialization_thread.start()
314
+
315
+ print(f"\n💻 LANZANDO INTERFAZ GRADIO (El entrenamiento V1.0 se ejecuta en segundo plano)")
316
+ demo.launch()