Andro0s commited on
Commit
03b3c9a
·
verified ·
1 Parent(s): 847ca8a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +242 -0
app.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from huggingface_hub import login
4
+ from datasets import load_dataset
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
6
+ from peft import get_peft_model, LoraConfig, TaskType, PeftModel
7
+ import json
8
+
9
+ # ============================================================
10
+ # ⚙️ CONFIGURACIÓN GLOBAL
11
+ # ============================================================
12
+ # Modelo base para generación de código
13
+ BASE_MODEL = "bigcode/santacoder" 
14
+ LORA_PATH = "./lora_output"        # Directorio para guardar los adaptadores LoRA
15
+
16
+ # Nombre del archivo donde se guardará el dataset procesado
17
+ DATASET_FILE = "codesearchnet_lora_dataset.json"  
18
+ MAX_TOKEN_LENGTH = 256             # Longitud de secuencia uniforme
19
+ NUM_SAMPLES_TO_PROCESS = 5000     
20
+ DEFAULT_EPOCHS = 5 # <--- ¡ENTRENAMIENTO PROFUNDO!
21
+
22
+ # Variables globales
23
+ tokenizer = None
24
+ lora_model = None
25
+ tokenized_dataset = None
26
+ lora_generator = None
27
+
28
+ # ============================================================
29
+ # 🚨 LÓGICA DE PRE-PROCESAMIENTO DE DATOS (INTEGRADA) 🚨
30
+ # ============================================================
31
+ def prepare_codesearchnet():
32
+     """Descarga, procesa y guarda el dataset CodeSearchNet si no existe."""
33
+     if os.path.exists(DATASET_FILE):
34
+         print(f"✅ Dataset '{DATASET_FILE}' ya existe.")
35
+         return
36
+
37
+     print(f"🔄 Descargando y procesando CodeSearchNet ({NUM_SAMPLES_TO_PROCESS} muestras)...")
38
+    
39
+     try:
40
+         raw_csn = load_dataset('Nan-Do/code-search-net-python', split=f'train[:{NUM_SAMPLES_TO_PROCESS}]')
41
+
42
+         def format_for_lora(example):
43
+             prompt_text = (
44
+                 f"# Descripción: {example['docstring_summary']}\n"
45
+                 f"# Completa la siguiente función:\n"
46
+                 f"def {example['func_name']}("
47
+             )
48
+             completion_text = example['code']
49
+            
50
+             return {
51
+                 "prompt": prompt_text,
52
+                 "completion": completion_text
53
+             }
54
+
55
+         lora_dataset = raw_csn.map(
56
+             format_for_lora,
57
+             batched=False,
58
+             remove_columns=raw_csn["train"].column_names,
59
+         )
60
+
61
+         lora_dataset.to_json(DATASET_FILE)
62
+         print(f"✅ Pre-procesamiento completado. {NUM_SAMPLES_TO_PROCESS} ejemplos guardados en '{DATASET_FILE}'.")
63
+
64
+     except Exception as e:
65
+         print(f"❌ Error CRÍTICO al descargar/procesar CodeSearchNet. Error: {e}")
66
+         minimal_dataset = [{"prompt": "# Error de carga. Intenta de nuevo.", "completion": "pass\n"}] * 10
67
+         with open(DATASET_FILE, 'w') as f:
68
+             json.dump(minimal_dataset, f)
69
+
70
+ # ============================================================
71
+ # 🔐 AUTENTICACIÓN Y PRE-CARGA DE RECURSOS (SINGLETON)
72
+ # ============================================================
73
+
74
+ def setup_resources():
75
+     """Carga y configura todos los recursos (modelo, tokenizer, dataset) una sola vez."""
76
+     global tokenizer, lora_model, tokenized_dataset
77
+
78
+     prepare_codesearchnet()
79
+    
80
+     hf_token = os.environ.get("HF_TOKEN")
81
+     if hf_token:
82
+         login(token=hf_token)
83
+
84
+     # 1. Carga del Tokenizer y Modelo Base
85
+     print("\n🔄 Cargando modelo base y tokenizer...")
86
+     tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
87
+     base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
88
+
89
+     if tokenizer.pad_token is None:
90
+         tokenizer.pad_token = tokenizer.eos_token
91
+
92
+     # 2. Configuración y Aplicación LoRA (PEFT)
93
+     peft_config = LoraConfig(
94
+         task_type=TaskType.CAUSAL_LM,
95
+         r=8,
96
+         lora_alpha=32,
97
+         lora_dropout=0.1,
98
+         target_modules=["c_proj", "c_attn"],
99
+     )
100
+     lora_model = get_peft_model(base_model, peft_config)
101
+    
102
+     print(f"✅ Modelo LoRA preparado. Parámetros entrenables listos.")
103
+
104
+     # 3. Carga y Tokenización del Dataset
105
+     print(f"📚 Cargando y tokenizando dataset: {DATASET_FILE}...")
106
+     try:
107
+         raw_dataset = load_dataset("json", data_files=DATASET_FILE)
108
+        
109
+         def tokenize_function(examples):
110
+             return tokenizer(
111
+                 examples["prompt"] + examples["completion"],
112
+                 truncation=True,
113
+                 padding="max_length",
114
+                 max_length=MAX_TOKEN_LENGTH
115
+             )
116
+
117
+         tokenized_dataset = raw_dataset.map(
118
+             tokenize_function,
119
+             batched=True,
120
+             remove_columns=raw_dataset["train"].column_names if "train" in raw_dataset else [],
121
+         )
122
+         print("✅ Dataset tokenizado correctamente.")
123
+     except Exception as e:
124
+         tokenized_dataset = None
125
+         print(f"❌ Error al cargar o tokenizar el dataset. {e}")
126
+
127
+
128
+ # ============================================================
129
+ # 🧩 FUNCIÓN DE ENTRENAMIENTO
130
+ # ============================================================
131
+ def train_lora(epochs, batch_size, learning_rate):
132
+     """Ejecuta el entrenamiento del modelo LoRA."""
133
+     global lora_model, tokenized_dataset, lora_generator
134
+
135
+     if tokenized_dataset is None or "train" not in tokenized_dataset:
136
+         return f"❌ Error: El dataset no pudo cargarse o está vacío. No se puede entrenar."
137
+
138
+     try:
139
+         lora_generator = None
140
+         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
141
+
142
+         training_args = TrainingArguments(
143
+             output_dir=LORA_PATH,
144
+             per_device_train_batch_size=int(batch_size),
145
+             num_train_epochs=float(epochs),
146
+             learning_rate=float(learning_rate),
147
+             save_total_limit=1,
148
+             logging_steps=10,
149
+             push_to_hub=False,
150
+         )
151
+
152
+         trainer = Trainer(
153
+             model=lora_model,
154
+             args=training_args,
155
+             train_dataset=tokenized_dataset["train"],
156
+             data_collator=data_collator,
157
+         )
158
+
159
+         trainer.train()
160
+        
161
+         lora_model.save_pretrained(LORA_PATH)
162
+         tokenizer.save_pretrained(LORA_PATH)
163
+
164
+         return f"✅ Entrenamiento completado. Adaptadores LoRA guardados en **{LORA_PATH}**"
165
+     except Exception as e:
166
+         return f"❌ Error durante el entrenamiento: {e}"
167
+
168
+ # ============================================================
169
+ # 🤖 FUNCIÓN DE GENERACIÓN (INFERENCIA)
170
+ # ============================================================
171
+ def generate_text(prompt_text):
172
+     """Genera texto usando el modelo base + adaptadores LoRA."""
173
+     global lora_generator
174
+
175
+     try:
176
+         if lora_generator is None:
177
+             base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
178
+            
179
+             if os.path.exists(LORA_PATH):
180
+                 print("Cargando adaptadores LoRA...")
181
+                 model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
182
+             else:
183
+                 print("No se encontraron adaptadores LoRA. Usando modelo base.")
184
+                 model_with_lora = base_model_gen
185
+
186
+             final_model = model_with_lora.merge_and_unload()
187
+             final_model.eval()
188
+            
189
+             lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer)
190
+             print("Modelo de inferencia listo.")
191
+
192
+         output = lora_generator(prompt_text, max_new_tokens=150, temperature=0.7, top_p=0.9)
193
+         return output[0]["generated_text"]
194
+        
195
+     except Exception as e:
196
+         return f"❌ Error generando texto (Asegúrate de que el modelo base/LoRA esté cargado): {e}"
197
+
198
+ # ============================================================
199
+ # 💻 INTERFAZ GRADIO
200
+ # ============================================================
201
+ with gr.Blocks(title="AmorCoderAI - LoRA") as demo:
202
+     gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas LoRA")
203
+     gr.Markdown(f"Modelo base: `{BASE_MODEL}`. Usando **{NUM_SAMPLES_TO_PROCESS}** ejemplos de CodeSearchNet.")
204
+
205
+     with gr.Tab("🧠 Entrenar (Manual)"):
206
+         gr.Markdown(f"--- **¡CUIDADO!** El auto-entrenamiento usará {DEFAULT_EPOCHS} épocas para aprender la sintaxis. ---")
207
+         epochs = gr.Number(value=DEFAULT_EPOCHS, label="Épocas", precision=0)
208
+         batch_size = gr.Number(value=2, label="Tamaño de lote (ajusta según tu VRAM)", precision=0)
209
+         learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje")
210
+         train_button = gr.Button("🚀 Iniciar Entrenamiento Manual")
211
+         train_output = gr.Textbox(label="Resultado del Entrenamiento Manual")
212
+        
213
+         train_button.click(
214
+             train_lora,
215
+             inputs=[epochs, batch_size, learning_rate],
216
+             outputs=train_output
217
+         )
218
+
219
+     with gr.Tab("✨ Probar modelo"):
220
+         prompt = gr.Textbox(label="Escribe código (ej: 'def fibonacci(n):')", lines=4)
221
+         generate_button = gr.Button("💬 Generar código")
222
+         output_box = gr.Textbox(label="Salida generada", lines=10)
223
+         generate_button.click(generate_text, inputs=prompt, outputs=output_box)
224
+
225
+ # ============================================================
226
+ # 🚀 LANZAR APP Y AUTO-ENTRENAMIENTO
227
+ # ============================================================
228
+ if __name__ == "__main__":
229
+     setup_resources()
230
+    
231
+     print("\n=============================================")
232
+     print(f"🤖 INICIANDO AUTO-ENTRENAMIENTO ({DEFAULT_EPOCHS} Épocas, 2 Batch Size) usando {NUM_SAMPLES_TO_PROCESS} ejemplos")
233
+     print("=============================================")
234
+    
235
+     auto_train_result = train_lora(epochs=DEFAULT_EPOCHS, batch_size=2, learning_rate=5e-5)
236
+    
237
+     print(f"\nFIN DEL AUTO-ENTRENAMIENTO: {auto_train_result}")
238
+    
239
+     print("\n=============================================")
240
+     print("💻 LANZANDO INTERFAZ GRADIO")
241
+     print("=============================================")
242
+     demo.launch()