Allex21 committed on
Commit
9275790
·
verified ·
1 Parent(s): 5b0bbfc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -47
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import gradio as gr
3
  from preprocess import process_dataset
@@ -5,18 +6,20 @@ import subprocess
5
  import time
6
 
7
  def train_lora_interface(dataset_zip, model_name, lora_rank, learning_rate, num_epochs, hub_token):
8
- # 1. Pré-processamento
9
- with gr.Progress() as progress:
10
- progress(0, "Descompactando e processando dataset...")
 
 
 
11
  dataset_dir = process_dataset(dataset_zip, "processed_data")
12
-
13
- # 2. Configura treinamento
14
- progress(0.3, "Configurando treinamento LoRA...")
 
15
  output_dir = "lora-output"
16
  os.makedirs(output_dir, exist_ok=True)
17
-
18
- # 3. Executa treinamento
19
- progress(0.5, "Treinando modelo (isso pode levar horas)...")
20
  cmd = [
21
  "python", "train_lora.py",
22
  "--dataset_dir", dataset_dir,
@@ -24,66 +27,76 @@ def train_lora_interface(dataset_zip, model_name, lora_rank, learning_rate, num_
24
  "--lora_rank", str(lora_rank),
25
  "--learning_rate", str(learning_rate),
26
  "--num_epochs", str(num_epochs),
 
27
  "--output_dir", output_dir
28
  ]
29
-
30
  if hub_token:
31
  os.environ["HF_TOKEN"] = hub_token
32
  cmd.append("--push_to_hub")
33
  cmd.append("--hub_model_id")
34
  cmd.append("my-lora-model")
 
 
 
35
 
36
  process = subprocess.Popen(
37
- cmd,
38
- stdout=subprocess.PIPE,
39
  stderr=subprocess.STDOUT,
40
- universal_newlines=True
 
 
41
  )
42
-
43
- logs = ""
44
  for line in process.stdout:
45
- logs += line
46
- progress(0.7, f"Treinando...\n{logs[-200:]}")
47
-
48
- # 4. Finalização
49
- progress(0.9, "Subindo para Hugging Face Hub...")
50
- if hub_token:
51
- from huggingface_hub import upload_folder
52
- upload_folder(
53
- repo_id="my-lora-model",
54
- folder_path=output_dir,
55
- token=hub_token
56
- )
57
-
58
- progress(1.0, "Treinamento concluído com sucesso!")
59
- return f"Modelo salvo em: {output_dir}\nLogs: {logs[-500:]}"
60
 
61
- # Interface Gradio
62
- with gr.Blocks(title="LoRA Trainer - Hugging Face") as demo:
63
- gr.Markdown("# 🚀 Treinador de LoRA para Stable Diffusion")
64
- gr.Markdown("Treine seus próprios modelos LoRA diretamente no Hugging Face Spaces!")
65
 
66
- with gr.Tab("Configuração"):
67
- dataset_zip = gr.File(label="Dataset (ZIP com imagens)", file_types=['.zip'])
 
 
 
 
 
 
 
 
 
 
68
  model_name = gr.Dropdown(
69
- ["runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-2-1"],
70
  value="runwayml/stable-diffusion-v1-5",
71
  label="Modelo Base"
72
  )
73
- lora_rank = gr.Slider(4, 64, value=4, step=4, label="Rank LoRA")
74
  learning_rate = gr.Number(value=1e-4, label="Taxa de Aprendizado")
75
- num_epochs = gr.Slider(1, 50, value=10, step=1, label="Épocas")
76
- hub_token = gr.Textbox(label="Token Hugging Face (opcional)", type="password")
77
-
78
- with gr.Tab("Treinamento"):
79
- start_btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
80
- output = gr.Textbox(label="Logs do Treinamento")
81
-
82
- start_btn.click(
83
- fn=train_lora_interface,
84
  inputs=[dataset_zip, model_name, lora_rank, learning_rate, num_epochs, hub_token],
85
  outputs=output
86
  )
87
 
 
 
 
88
  if __name__ == "__main__":
89
  demo.launch()
 
1
+ # app.py
2
  import os
3
  import gradio as gr
4
  from preprocess import process_dataset
 
6
  import time
7
 
8
def train_lora_interface(dataset_zip, model_name, lora_rank, learning_rate, num_epochs, hub_token):
    """Run LoRA fine-tuning from the Gradio UI, streaming status messages.

    Generator used as a Gradio event handler: each ``yield`` replaces the
    contents of the output textbox (requires ``demo.queue()`` to be enabled).

    Args:
        dataset_zip: Uploaded ZIP with training images (falsy when missing).
        model_name: Base model id forwarded to ``train_lora.py``.
        lora_rank: LoRA rank (stringified for the CLI).
        learning_rate: Learning rate (stringified for the CLI).
        num_epochs: Number of epochs (stringified for the CLI).
        hub_token: Optional Hugging Face token; enables push-to-hub.

    Yields:
        User-facing status / log strings (Portuguese).
    """
    if not dataset_zip:
        # BUG FIX: the original did `return "❌ …"` inside a generator, which
        # ends iteration WITHOUT emitting the message, so the UI stayed blank.
        yield "❌ Por favor, envie um ZIP com suas imagens."
        return

    # Step 1: unpack and preprocess the dataset.
    yield "🔄 Descompactando e processando dataset..."
    try:
        dataset_dir = process_dataset(dataset_zip, "processed_data")
        image_count = len([
            f for f in os.listdir(dataset_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ])
        yield f"✅ Dataset processado: {image_count} imagens encontradas. Iniciando treinamento..."

        # Step 2: configure the training run.
        output_dir = "lora-output"
        os.makedirs(output_dir, exist_ok=True)

        cmd = [
            "python", "train_lora.py",
            "--dataset_dir", dataset_dir,
            "--model_name", model_name,
            "--lora_rank", str(lora_rank),
            "--learning_rate", str(learning_rate),
            "--num_epochs", str(num_epochs),
            "--batch_size", "1",
            "--output_dir", output_dir,
        ]

        if hub_token:
            # Token goes into the environment so the child script can
            # authenticate; it is never echoed into the streamed logs.
            os.environ["HF_TOKEN"] = hub_token
            cmd += ["--push_to_hub", "--hub_model_id", "my-lora-model"]

        # Step 3: run training in a subprocess, streaming merged stdout+stderr.
        yield "🔥 Treinando modelo... Isso pode levar alguns minutos."

        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into the streamed log
            universal_newlines=True,
            bufsize=1,                 # line-buffered in text mode
            encoding='utf-8'
        )

        log_output = ""
        line_count = 0
        for line in process.stdout:
            log_output += line
            line_count += 1
            # Refresh the UI every 20 lines. (The original re-split the whole
            # accumulated log on every line — O(n^2) — and its comment said
            # "50 linhas" while the code checked 20.)
            if line_count % 20 == 0:
                yield f"📝 Treinando...\n{log_output[-500:]}"
        process.stdout.close()  # avoid leaking the pipe
        process.wait()

        if process.returncode == 0:
            yield f"🎉 Treinamento concluído! Modelo salvo em `{output_dir}`"
        else:
            # Restored the ❌ marker evidently dropped from the message prefix
            # (every sibling message carries an emoji; this one began with a
            # stray space).
            yield f"❌ Falha no treinamento. Código: {process.returncode}\nÚltimos logs:\n{log_output[-1000:]}"

    except Exception as e:
        # Top-level UI boundary: surface any failure to the user instead of
        # crashing the event handler.
        yield f"💥 Erro crítico: {str(e)}\n\nVerifique seus arquivos e tente novamente."
# --- Gradio front-end -------------------------------------------------------
# NOTE(review): leading indentation was lost in extraction; the Row/component
# nesting below is the most plausible reconstruction — confirm against the app.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
    gr.Markdown("Treine modelos personalizados com apenas algumas imagens!")

    with gr.Row():
        dataset_zip = gr.File(label="📤 Envie um ZIP com suas imagens (.jpg, .png)")

    with gr.Row():
        model_name = gr.Dropdown(
            choices=["runwayml/stable-diffusion-v1-5"],
            value="runwayml/stable-diffusion-v1-5",
            label="Modelo Base",
        )
        lora_rank = gr.Slider(minimum=4, maximum=64, value=4, step=4, label="LoRA Rank")
        learning_rate = gr.Number(value=1e-4, label="Taxa de Aprendizado")
        num_epochs = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Épocas")

    hub_token = gr.Textbox(label="🔐 Token do Hugging Face (opcional)", type="password")

    start_button = gr.Button("🚀 Iniciar Treinamento", variant="primary")
    status_box = gr.Textbox(label="📊 Status / Logs", lines=10)

    # The handler is a generator: each yield streams a status update into the
    # status textbox.
    start_button.click(
        train_lora_interface,
        inputs=[dataset_zip, model_name, lora_rank, learning_rate, num_epochs, hub_token],
        outputs=status_box,
    )

# Queueing is mandatory for generator handlers to stream `yield`ed updates.
demo.queue()

if __name__ == "__main__":
    demo.launch()