Andro0s commited on
Commit
413d0ff
·
verified ·
1 Parent(s): 2da6038

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -28
app.py CHANGED
@@ -2,7 +2,14 @@ import os
2
  import gradio as gr
3
  from huggingface_hub import login
4
  from datasets import load_dataset
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
 
 
 
 
 
 
 
6
 
7
  # ============================================================
8
  # 🔐 Autenticación segura con tu token
@@ -12,39 +19,61 @@ hf_token = os.environ.get("HF_TOKEN")
12
  if hf_token:
13
  login(token=hf_token)
14
  else:
15
- print("⚠️ No se encontró el token. Agrega 'HF_TOKEN' en Settings → Secrets.")
16
 
17
  # ============================================================
18
  # ⚙️ Configuración del modelo base y dataset
19
  # ============================================================
20
- MODEL_NAME = "bigcode/santacoder" # Modelo público similar a StarCoder
21
- DATASET_PATH = "dataset.json" # Tu dataset subido al Space
22
- OUTPUT_DIR = "./lora_output"
 
 
 
23
 
24
  # Cargar modelo y tokenizer
25
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
26
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
 
27
 
28
  # ============================================================
29
- # 🧩 Función de entrenamiento LoRA
30
  # ============================================================
31
  def train_lora(epochs, batch_size, learning_rate):
32
  try:
 
33
  dataset = load_dataset("json", data_files=DATASET_PATH)
34
- tokenized = dataset.map(lambda e: tokenizer(e["prompt"] + e["completion"], truncation=True, padding="max_length", max_length=256))
35
 
36
- data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
 
38
  training_args = TrainingArguments(
39
  output_dir=OUTPUT_DIR,
40
- per_device_train_batch_size=batch_size,
41
- num_train_epochs=epochs,
42
- learning_rate=learning_rate,
43
- save_total_limit=1,
44
  logging_steps=10,
45
- push_to_hub=False
 
 
46
  )
47
 
 
48
  trainer = Trainer(
49
  model=model,
50
  args=training_args,
@@ -52,47 +81,55 @@ def train_lora(epochs, batch_size, learning_rate):
52
  data_collator=data_collator,
53
  )
54
 
 
55
  trainer.train()
 
 
56
  model.save_pretrained(OUTPUT_DIR)
57
  tokenizer.save_pretrained(OUTPUT_DIR)
58
 
59
- return "✅ Entrenamiento completado con éxito y guardado en ./lora_output"
60
  except Exception as e:
61
- return f"❌ Error durante el entrenamiento: {e}"
62
 
63
  # ============================================================
64
- # 🤖 Función de prueba del modelo
65
  # ============================================================
66
  def generate_text(prompt):
67
- generator = pipeline("text-generation", model=OUTPUT_DIR, tokenizer=tokenizer)
68
- output = generator(prompt, max_new_tokens=100, temperature=0.7, top_p=0.9)
69
- return output[0]["generated_text"]
 
 
 
 
 
 
 
70
 
71
  # ============================================================
72
  # 💻 Interfaz de usuario (Gradio)
73
  # ============================================================
74
- with gr.Blocks(title="AmorCoderAI - Entrenamiento LoRA") as demo:
75
  gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas")
76
- gr.Markdown("Entrena y prueba tu modelo basado en `bigcode/santacoder` con LoRA")
77
 
78
  with gr.Tab("🧠 Entrenar"):
79
  epochs = gr.Number(value=1, label="Épocas")
80
  batch_size = gr.Number(value=2, label="Tamaño de lote")
81
  learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje")
82
  train_button = gr.Button("🚀 Iniciar entrenamiento")
83
- train_output = gr.Textbox(label="Resultado")
84
-
85
  train_button.click(train_lora, inputs=[epochs, batch_size, learning_rate], outputs=train_output)
86
 
87
  with gr.Tab("✨ Probar modelo"):
88
  prompt = gr.Textbox(label="Escribe un prompt")
89
  generate_button = gr.Button("💬 Generar texto")
90
- output_box = gr.Textbox(label="Salida generada")
91
-
92
  generate_button.click(generate_text, inputs=prompt, outputs=output_box)
93
 
94
  # ============================================================
95
  # 🚀 Lanzar app
96
  # ============================================================
97
  if __name__ == "__main__":
98
- demo.launch()
 
2
  import gradio as gr
3
  from huggingface_hub import login
4
  from datasets import load_dataset
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModelForCausalLM,
8
+ Trainer,
9
+ TrainingArguments,
10
+ DataCollatorForLanguageModeling,
11
+ pipeline,
12
+ )
13
 
14
  # ============================================================
15
  # 🔐 Autenticación segura con tu token
 
19
  if hf_token:
20
  login(token=hf_token)
21
  else:
22
+ print("⚠️ No se encontró el token. Agrega 'HF_TOKEN' en Settings → Secrets → Add new secret")
23
 
24
  # ============================================================
25
  # ⚙️ Configuración del modelo base y dataset
26
  # ============================================================
27
+ MODEL_NAME = "bigcode/santacoder" # Modelo libre y compatible con Hugging Face
28
+ DATASET_PATH = "dataset.json" # Archivo dataset que subiste al Space
29
+ OUTPUT_DIR = "lora_output" # Carpeta donde se guarda el modelo entrenado
30
+
31
+ # Crear carpeta de salida si no existe
32
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
33
 
34
  # Cargar modelo y tokenizer
35
+ print("🔄 Cargando modelo base...")
36
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=hf_token)
37
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_auth_token=hf_token)
38
 
39
  # ============================================================
40
+ # 🧩 Función de entrenamiento LoRA (simple y funcional)
41
  # ============================================================
42
  def train_lora(epochs, batch_size, learning_rate):
43
  try:
44
+ # Cargar dataset JSON
45
  dataset = load_dataset("json", data_files=DATASET_PATH)
 
46
 
47
+ # Tokenización del dataset
48
+ def tokenize_fn(example):
49
+ text = example["prompt"] + example["completion"]
50
+ return tokenizer(
51
+ text,
52
+ truncation=True,
53
+ padding="max_length",
54
+ max_length=256,
55
+ )
56
+
57
+ tokenized = dataset.map(tokenize_fn, batched=True)
58
+
59
+ # Preparar data collator
60
+ data_collator = DataCollatorForLanguageModeling(
61
+ tokenizer=tokenizer, mlm=False
62
+ )
63
 
64
+ # Configuración del entrenamiento
65
  training_args = TrainingArguments(
66
  output_dir=OUTPUT_DIR,
67
+ per_device_train_batch_size=int(batch_size),
68
+ num_train_epochs=int(epochs),
69
+ learning_rate=float(learning_rate),
 
70
  logging_steps=10,
71
+ save_total_limit=1,
72
+ push_to_hub=False,
73
+ report_to="none",
74
  )
75
 
76
+ # Entrenador
77
  trainer = Trainer(
78
  model=model,
79
  args=training_args,
 
81
  data_collator=data_collator,
82
  )
83
 
84
+ # Entrenar modelo
85
  trainer.train()
86
+
87
+ # Guardar resultados
88
  model.save_pretrained(OUTPUT_DIR)
89
  tokenizer.save_pretrained(OUTPUT_DIR)
90
 
91
+ return "✅ Entrenamiento completado con éxito. Modelo guardado en ./lora_output"
92
  except Exception as e:
93
+ return f"❌ Error durante el entrenamiento: {str(e)}"
94
 
95
  # ============================================================
96
+ # 🤖 Función de prueba del modelo entrenado
97
  # ============================================================
98
  def generate_text(prompt):
99
+ try:
100
+ generator = pipeline(
101
+ "text-generation",
102
+ model=OUTPUT_DIR,
103
+ tokenizer=tokenizer,
104
+ )
105
+ output = generator(prompt, max_new_tokens=100, temperature=0.7, top_p=0.9)
106
+ return output[0]["generated_text"]
107
+ except Exception as e:
108
+ return f"��️ Error al generar texto: {str(e)}"
109
 
110
  # ============================================================
111
  # 💻 Interfaz de usuario (Gradio)
112
  # ============================================================
113
+ with gr.Blocks(title="💙 AmorCoderAI - Entrenamiento LoRA") as demo:
114
  gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas")
115
+ gr.Markdown("Entrena y prueba tu modelo basado en `bigcode/santacoder` con LoRA.")
116
 
117
  with gr.Tab("🧠 Entrenar"):
118
  epochs = gr.Number(value=1, label="Épocas")
119
  batch_size = gr.Number(value=2, label="Tamaño de lote")
120
  learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje")
121
  train_button = gr.Button("🚀 Iniciar entrenamiento")
122
+ train_output = gr.Textbox(label="Resultado", lines=3)
 
123
  train_button.click(train_lora, inputs=[epochs, batch_size, learning_rate], outputs=train_output)
124
 
125
  with gr.Tab("✨ Probar modelo"):
126
  prompt = gr.Textbox(label="Escribe un prompt")
127
  generate_button = gr.Button("💬 Generar texto")
128
+ output_box = gr.Textbox(label="Salida generada", lines=6)
 
129
  generate_button.click(generate_text, inputs=prompt, outputs=output_box)
130
 
131
  # ============================================================
132
  # 🚀 Lanzar app
133
  # ============================================================
134
  if __name__ == "__main__":
135
+ demo.launch(server_name="0.0.0.0", server_port=7860)