Allex21 commited on
Commit
09dea70
·
verified ·
1 Parent(s): 9275790

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -47
app.py CHANGED
@@ -3,43 +3,94 @@ import os
3
  import gradio as gr
4
  from preprocess import process_dataset
5
  import subprocess
 
6
  import time
7
 
8
- def train_lora_interface(dataset_zip, model_name, lora_rank, learning_rate, num_epochs, hub_token):
9
- if not dataset_zip:
10
- return "❌ Por favor, envie um ZIP com suas imagens."
11
-
12
- # Etapa 1: Pré-processamento
13
- yield "🔄 Descompactando e processando dataset..."
14
- try:
15
- dataset_dir = process_dataset(dataset_zip, "processed_data")
16
- image_count = len([f for f in os.listdir(dataset_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
17
- yield f"✅ Dataset processado: {image_count} imagens encontradas. Iniciando treinamento..."
18
-
19
- # Etapa 2: Configura treinamento
20
- output_dir = "lora-output"
21
- os.makedirs(output_dir, exist_ok=True)
22
-
23
- cmd = [
24
- "python", "train_lora.py",
25
- "--dataset_dir", dataset_dir,
26
- "--model_name", model_name,
27
- "--lora_rank", str(lora_rank),
28
- "--learning_rate", str(learning_rate),
29
- "--num_epochs", str(num_epochs),
30
- "--batch_size", "1",
31
- "--output_dir", output_dir
32
- ]
33
-
34
- if hub_token:
35
- os.environ["HF_TOKEN"] = hub_token
36
- cmd.append("--push_to_hub")
37
- cmd.append("--hub_model_id")
38
- cmd.append("my-lora-model")
39
-
40
- # Etapa 3: Executa treinamento
41
- yield "🔥 Treinando modelo... Isso pode levar alguns minutos."
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  process = subprocess.Popen(
44
  cmd,
45
  stdout=subprocess.PIPE,
@@ -52,28 +103,39 @@ def train_lora_interface(dataset_zip, model_name, lora_rank, learning_rate, num_
52
  log_output = ""
53
  for line in process.stdout:
54
  log_output += line
55
- # Mostra os últimos logs a cada 50 linhas
56
- if len(log_output.split('\n')) % 20 == 0:
57
- yield f"📝 Treinando...\n{log_output[-500:]}"
58
 
59
  process.wait()
60
 
61
  if process.returncode == 0:
62
- yield f"🎉 Treinamento concluído! Modelo salvo em `{output_dir}`"
63
  else:
64
- yield f"❌ Falha no treinamento. Código: {process.returncode}\nÚltimos logs:\n{log_output[-1000:]}"
65
-
66
  except Exception as e:
67
- yield f"💥 Erro crítico: {str(e)}\n\nVerifique seus arquivos e tente novamente."
68
 
69
- # Interface
70
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
71
  gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
72
- gr.Markdown("Treine modelos personalizados com apenas algumas imagens!")
73
 
74
  with gr.Row():
75
- dataset_zip = gr.File(label="📤 Envie um ZIP com suas imagens (.jpg, .png)")
 
 
 
 
 
 
 
 
 
 
 
76
 
 
77
  with gr.Row():
78
  model_name = gr.Dropdown(
79
  ["runwayml/stable-diffusion-v1-5"],
@@ -87,15 +149,15 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
87
  hub_token = gr.Textbox(label="🔐 Token do Hugging Face (opcional)", type="password")
88
 
89
  btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
90
- output = gr.Textbox(label="📊 Status / Logs", lines=10)
91
 
92
  btn.click(
93
  train_lora_interface,
94
- inputs=[dataset_zip, model_name, lora_rank, learning_rate, num_epochs, hub_token],
95
  outputs=output
96
  )
97
 
98
- # Ativa fila para suporte a `yield`
99
  demo.queue()
100
 
101
  if __name__ == "__main__":
 
3
  import gradio as gr
4
  from preprocess import process_dataset
5
  import subprocess
6
+ import zipfile
7
  import time
8
 
9
+ def train_lora_interface(dataset_input, input_type, model_name, lora_rank, learning_rate, num_epochs, hub_token):
10
+ if not dataset_input:
11
+ return "❌ Por favor, envie um ZIP ou selecione imagens."
12
+
13
+ yield "📁 Preparando dataset..."
14
+
15
+ # Pasta temporária
16
+ os.makedirs("processed_data", exist_ok=True)
17
+ for f in os.listdir("processed_data"):
18
+ fp = os.path.join("processed_data", f)
19
+ try:
20
+ if os.path.isfile(fp) or os.path.islink(fp):
21
+ os.unlink(fp)
22
+ elif os.path.isdir(fp):
23
+ os.rmtree(fp)
24
+ except Exception as e:
25
+ yield f"⚠️ Erro ao limpar: {e}"
26
+
27
+ dataset_dir = "processed_data"
28
+
29
+ if input_type == "Upload de ZIP":
30
+ zip_path = dataset_input
31
+ if not zipfile.is_zipfile(zip_path):
32
+ yield "❌ Arquivo enviado não é um ZIP válido."
33
+ return
 
 
 
 
 
 
 
 
 
34
 
35
+ # Descompacta
36
+ with zipfile.ZipFile(zip_path, 'r') as z:
37
+ z.extractall(dataset_dir)
38
+ yield f"✅ ZIP descompactado! {len(z.namelist())} arquivos extraídos."
39
+
40
+ else: # Múltiplas imagens
41
+ image_count = 0
42
+ for img_path in dataset_input:
43
+ dest_path = os.path.join(dataset_dir, os.path.basename(img_path.name))
44
+ os.rename(img_path.name, dest_path) # Move para processed_data
45
+
46
+ # Cria .txt vazio (BLIP vai preencher depois)
47
+ txt_path = os.path.splitext(dest_path)[0] + ".txt"
48
+ if not os.path.exists(txt_path):
49
+ with open(txt_path, "w") as f:
50
+ f.write("person")
51
+ image_count += 1
52
+ yield f"✅ {image_count} imagens copiadas para o dataset."
53
+
54
+ # Conta imagens processadas
55
+ image_files = [f for f in os.listdir(dataset_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
56
+ if len(image_files) == 0:
57
+ yield "❌ Nenhuma imagem encontrada. Envie JPG, PNG ou ZIP com imagens."
58
+ return
59
+
60
+ yield f"🖼️ Dataset pronto com {len(image_files)} imagens. Iniciando pré-processamento..."
61
+
62
+ # Gera captions com BLIP (reutiliza preprocess.py)
63
+ try:
64
+ from preprocess import process_dataset as run_blip
65
+ run_blip(None, dataset_dir, generate_captions=True) # Já temos as imagens, só gera captions
66
+ except Exception as e:
67
+ yield f"⚠️ Falha ao gerar legendas: {str(e)}. Continuando com captions existentes."
68
+
69
+ yield "🔥 Iniciando treinamento LoRA..."
70
+
71
+ # Comando de treinamento
72
+ output_dir = "lora-output"
73
+ os.makedirs(output_dir, exist_ok=True)
74
+
75
+ cmd = [
76
+ "python", "train_lora.py",
77
+ "--dataset_dir", dataset_dir,
78
+ "--model_name", model_name,
79
+ "--lora_rank", str(lora_rank),
80
+ "--learning_rate", str(learning_rate),
81
+ "--num_epochs", str(num_epochs),
82
+ "--batch_size", "1",
83
+ "--output_dir", output_dir
84
+ ]
85
+
86
+ if hub_token:
87
+ os.environ["HF_TOKEN"] = hub_token
88
+ cmd.append("--push_to_hub")
89
+ cmd.append("--hub_model_id")
90
+ cmd.append("my-lora-model")
91
+
92
+ # Executa treinamento
93
+ try:
94
  process = subprocess.Popen(
95
  cmd,
96
  stdout=subprocess.PIPE,
 
103
  log_output = ""
104
  for line in process.stdout:
105
  log_output += line
106
+ if "loss" in line.lower() or "epoch" in line.lower():
107
+ yield f"📊 {line.strip()}"
 
108
 
109
  process.wait()
110
 
111
  if process.returncode == 0:
112
+ yield f"🎉 SUCESSO! Modelo LoRA treinado e salvo em `{output_dir}`"
113
  else:
114
+ yield f"❌ Treinamento falhou com código {process.returncode}.\nLogs:\n{log_output[-1000:]}"
115
+
116
  except Exception as e:
117
+ yield f"💥 Erro ao executar treinamento: {str(e)}"
118
 
119
+ # Interface com opção de entrada
120
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
121
  gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
122
+ gr.Markdown("Envie suas imagens via ZIP ou múltiplos arquivos!")
123
 
124
  with gr.Row():
125
+ input_type = gr.Radio(
126
+ ["Upload de ZIP", "Selecionar várias imagens"],
127
+ label="Tipo de Entrada",
128
+ value="Upload de ZIP"
129
+ )
130
+
131
+ with gr.Row():
132
+ dataset_input = gr.File(
133
+ label="📤 Envie seu ZIP ou imagens",
134
+ file_types=[".zip", ".jpg", ".jpeg", ".png"],
135
+ file_count="multiple" # Permite múltiplos arquivos!
136
+ )
137
 
138
+ gr.Markdown("### ⚙️ Configurações do Treinamento")
139
  with gr.Row():
140
  model_name = gr.Dropdown(
141
  ["runwayml/stable-diffusion-v1-5"],
 
149
  hub_token = gr.Textbox(label="🔐 Token do Hugging Face (opcional)", type="password")
150
 
151
  btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
152
+ output = gr.Textbox(label="📦 Logs e Status", lines=12)
153
 
154
  btn.click(
155
  train_lora_interface,
156
+ inputs=[dataset_input, input_type, model_name, lora_rank, learning_rate, num_epochs, hub_token],
157
  outputs=output
158
  )
159
 
160
+ # Ativa fila para yield
161
  demo.queue()
162
 
163
  if __name__ == "__main__":