Allex21 commited on
Commit
1ab5ebf
·
verified ·
1 Parent(s): 594c14a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -55
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
4
  from preprocess import process_dataset
5
  import subprocess
6
  import zipfile
 
7
  import time
8
 
9
  def train_lora_interface(
@@ -11,7 +12,8 @@ def train_lora_interface(
11
  num_epochs, hub_token, concept_name, description
12
  ):
13
  if not dataset_input:
14
- return "❌ Por favor, envie um ZIP ou selecione imagens."
 
15
  if not concept_name.strip():
16
  yield "❌ Por favor, defina um nome para o conceito (ex: brenda)."
17
  return
@@ -19,70 +21,78 @@ def train_lora_interface(
19
  yield "❌ Por favor, adicione uma descrição base (ex: mulher, 30 anos, cabelo cacheado)."
20
  return
21
 
22
- concept_name = concept_name.strip().replace(" ", "_") # Evita espaços
23
  full_description = f"{description.strip()}, {concept_name}"
24
 
25
- yield f"🏷️ Conceito: '{concept_name}' → Descrição: '{full_description}'"
26
-
27
- yield "📁 Preparando dataset..."
28
 
29
- # Limpa pasta anterior
30
  dataset_dir = "processed_data"
31
- if os.path.exists(dataset_dir):
32
- for f in os.listdir(dataset_dir):
33
- fp = os.path.join(dataset_dir, f)
34
- try:
35
- if os.path.isfile(fp) or os.path.islink(fp):
36
- os.unlink(fp)
37
- elif os.path.isdir(fp):
38
- os.rmtree(fp)
39
- except Exception as e:
40
- yield f"⚠️ Erro ao limpar: {e}"
41
  os.makedirs(dataset_dir, exist_ok=True)
42
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  if input_type == "Upload de ZIP":
44
- zip_path = dataset_input
45
- if not zipfile.is_zipfile(zip_path):
46
- yield "❌ Arquivo enviado não é um ZIP válido."
 
 
47
  return
48
-
49
- with zipfile.ZipFile(zip_path, 'r') as z:
 
50
  z.extractall(dataset_dir)
51
- yield f"✅ ZIP descompactado! {len(z.namelist())} arquivos."
52
 
53
  else: # Múltiplas imagens
54
- image_count = 0
55
- for img_path in dataset_input:
56
- dest_path = os.path.join(dataset_dir, os.path.basename(img_path.name))
57
- os.rename(img_path.name, dest_path)
58
- image_count += 1
59
- yield f"✅ {image_count} imagens copiadas."
60
-
61
- # Verifica imagens
62
- image_files = [f for f in os.listdir(dataset_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
 
 
 
 
 
 
 
 
 
 
 
 
63
  if len(image_files) == 0:
64
- yield "❌ Nenhuma imagem encontrada. Envie JPG, PNG ou ZIP com imagens."
65
  return
66
 
67
- yield f"🖼️ Dataset pronto com {len(image_files)} imagens. Gerando legendas personalizadas..."
68
 
69
- # Gera .txt com base no conceito (sem depender só do BLIP)
70
  for img_name in image_files:
71
  txt_path = os.path.join(dataset_dir, os.path.splitext(img_name)[0] + ".txt")
72
- if not os.path.exists(txt_path): # Só cria se não existir
73
  with open(txt_path, "w", encoding="utf-8") as f:
74
  f.write(full_description)
75
 
76
- # Atualiza legendas com BLIP (opcional, mas mantém o conceito)
77
- try:
78
- from preprocess import process_dataset as run_blip
79
- run_blip(None, dataset_dir, generate_captions=False) # Já temos legendas boas!
80
- yield "📌 Legendas personalizadas aplicadas com sucesso!"
81
- except Exception as e:
82
- yield f"⚠️ Falha ao rodar BLIP: {str(e)}. Usando legendas manuais."
83
-
84
- yield "🔥 Iniciando treinamento LoRA..."
85
 
 
86
  output_dir = "lora-output"
87
  os.makedirs(output_dir, exist_ok=True)
88
 
@@ -103,6 +113,8 @@ def train_lora_interface(
103
  cmd.append("--hub_model_id")
104
  cmd.append(f"{concept_name}-lora")
105
 
 
 
106
  try:
107
  process = subprocess.Popen(
108
  cmd,
@@ -123,20 +135,19 @@ def train_lora_interface(
123
 
124
  if process.returncode == 0:
125
  yield f"""
126
- 🎉 SUCESSO! Modelo LoRA treinado!
127
 
128
- 🔹 Nome do conceito: **{concept_name}**
129
- 🔹 Use no prompt: `photo of {concept_name} in a garden`
130
  🔹 Modelo salvo em: `{output_dir}`
131
  {'🔹 Publicado no Hub!' if hub_token else ''}
132
  """
133
  else:
134
- yield f"❌ Treinamento falhou. Código: {process.returncode}\nLogs:\n{log_output[-1000:]}"
135
 
136
  except Exception as e:
137
- yield f"💥 Erro ao executar treinamento: {str(e)}"
138
 
139
- # Interface atualizada
140
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
141
  gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
142
  gr.Markdown("Treine personagens, estilos ou objetos com nome e descrição!")
@@ -151,21 +162,21 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
151
  with gr.Row():
152
  dataset_input = gr.File(
153
  label="📤 Envie seu ZIP ou imagens",
154
- file_types=[".zip", ".jpg", ".jpeg", ".png"],
155
  file_count="multiple"
156
  )
157
 
158
  gr.Markdown("### 🔖 Identidade do Personagem/Conceito")
159
  with gr.Row():
160
  concept_name = gr.Textbox(
161
- label="Nome do Conceito (ex: brenda, cyborg_x)",
162
- placeholder="Ex: brenda, super_heroi, estilo_anime",
163
  value=""
164
  )
165
  with gr.Row():
166
  description = gr.Textbox(
167
- label="Descrição Base (ex: mulher, 30 anos, cabelo cacheado)",
168
- placeholder="Ex: young woman, curly hair, green eyes, smiling",
169
  lines=2
170
  )
171
 
@@ -194,6 +205,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
194
  outputs=output
195
  )
196
 
 
197
  demo.queue()
 
198
  if __name__ == "__main__":
199
  demo.launch()
 
4
  from preprocess import process_dataset
5
  import subprocess
6
  import zipfile
7
+ import shutil
8
  import time
9
 
10
  def train_lora_interface(
 
12
  num_epochs, hub_token, concept_name, description
13
  ):
14
  if not dataset_input:
15
+ yield "❌ Por favor, envie um ZIP ou selecione imagens."
16
+ return
17
  if not concept_name.strip():
18
  yield "❌ Por favor, defina um nome para o conceito (ex: brenda)."
19
  return
 
21
  yield "❌ Por favor, adicione uma descrição base (ex: mulher, 30 anos, cabelo cacheado)."
22
  return
23
 
24
+ concept_name = concept_name.strip().replace(" ", "_")
25
  full_description = f"{description.strip()}, {concept_name}"
26
 
27
+ yield f"🏷️ Treinando conceito: '{concept_name}'"
 
 
28
 
29
+ # Pasta de trabalho
30
  dataset_dir = "processed_data"
 
 
 
 
 
 
 
 
 
 
31
  os.makedirs(dataset_dir, exist_ok=True)
32
 
33
+ # Limpa pasta anterior
34
+ for item in os.listdir(dataset_dir):
35
+ item_path = os.path.join(dataset_dir, item)
36
+ try:
37
+ if os.path.isfile(item_path) or os.path.islink(item_path):
38
+ os.unlink(item_path)
39
+ elif os.path.isdir(item_path):
40
+ shutil.rmtree(item_path)
41
+ except Exception as e:
42
+ yield f"⚠️ Erro ao limpar: {e}"
43
+
44
+ # --- ETAPA 1: Processar entrada (ZIP ou múltiplas imagens) ---
45
  if input_type == "Upload de ZIP":
46
+ # dataset_input é uma lista? Pega o primeiro item
47
+ zip_file = dataset_input[0] if isinstance(dataset_input, list) else dataset_input
48
+
49
+ if not zipfile.is_zipfile(zip_file):
50
+ yield "❌ Arquivo não é um ZIP válido."
51
  return
52
+
53
+ yield "📦 Descompactando ZIP..."
54
+ with zipfile.ZipFile(zip_file, 'r') as z:
55
  z.extractall(dataset_dir)
56
+ yield f"✅ ZIP extraído! {len(z.namelist())} arquivos."
57
 
58
  else: # Múltiplas imagens
59
+ image_files = dataset_input if isinstance(dataset_input, list) else [dataset_input]
60
+ yield f"🖼️ Recebidas {len(image_files)} imagens. Copiando..."
61
+
62
+ for uploaded_file in image_files:
63
+ # O Gradio dá um objeto com .name (caminho completo)
64
+ if hasattr(uploaded_file, 'name'):
65
+ src_path = uploaded_file.name
66
+ filename = os.path.basename(src_path)
67
+ dest_path = os.path.join(dataset_dir, filename)
68
+
69
+ # COPIA (não renomeia) porque /tmp é somente leitura
70
+ shutil.copy(src_path, dest_path)
71
+ else:
72
+ yield f"⚠️ Arquivo inválido: {uploaded_file}"
73
+
74
+ yield f"✅ {len(image_files)} imagens copiadas para {dataset_dir}."
75
+
76
+ # --- ETAPA 2: Verifica e gera legendas ---
77
+ image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')
78
+ image_files = [f for f in os.listdir(dataset_dir) if f.lower().endswith(image_extensions)]
79
+
80
  if len(image_files) == 0:
81
+ yield "❌ Nenhuma imagem encontrada. Envie imagens válidas (.jpg, .png, etc.)."
82
  return
83
 
84
+ yield f"📝 Gerando legendas personalizadas para {len(image_files)} imagens..."
85
 
86
+ # Gera .txt com descrição + conceito
87
  for img_name in image_files:
88
  txt_path = os.path.join(dataset_dir, os.path.splitext(img_name)[0] + ".txt")
89
+ if not os.path.exists(txt_path):
90
  with open(txt_path, "w", encoding="utf-8") as f:
91
  f.write(full_description)
92
 
93
+ yield "🔍 Legendas aplicadas com sucesso!"
 
 
 
 
 
 
 
 
94
 
95
+ # --- ETAPA 3: Treinamento ---
96
  output_dir = "lora-output"
97
  os.makedirs(output_dir, exist_ok=True)
98
 
 
113
  cmd.append("--hub_model_id")
114
  cmd.append(f"{concept_name}-lora")
115
 
116
+ yield "🔥 Iniciando treinamento LoRA... Isso pode levar minutos."
117
+
118
  try:
119
  process = subprocess.Popen(
120
  cmd,
 
135
 
136
  if process.returncode == 0:
137
  yield f"""
138
+ 🎉 TREINAMENTO CONCLUÍDO!
139
 
140
+ 🔹 Use no prompt: `photo of {concept_name} in the forest`
 
141
  🔹 Modelo salvo em: `{output_dir}`
142
  {'🔹 Publicado no Hub!' if hub_token else ''}
143
  """
144
  else:
145
+ yield f"❌ Falha no treinamento. Código: {process.returncode}\nLogs:\n{log_output[-1000:]}"
146
 
147
  except Exception as e:
148
+ yield f"💥 Erro crítico: {str(e)}"
149
 
150
+ # --- Interface Gradio ---
151
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
152
  gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
153
  gr.Markdown("Treine personagens, estilos ou objetos com nome e descrição!")
 
162
  with gr.Row():
163
  dataset_input = gr.File(
164
  label="📤 Envie seu ZIP ou imagens",
165
+ file_types=[".zip", ".jpg", ".jpeg", ".png", ".bmp", ".webp"],
166
  file_count="multiple"
167
  )
168
 
169
  gr.Markdown("### 🔖 Identidade do Personagem/Conceito")
170
  with gr.Row():
171
  concept_name = gr.Textbox(
172
+ label="Nome do Conceito (ex: brenda)",
173
+ placeholder="Ex: brenda, super_heroi",
174
  value=""
175
  )
176
  with gr.Row():
177
  description = gr.Textbox(
178
+ label="Descrição Base (ex: woman, curly hair, brown eyes)",
179
+ placeholder="Ex: young black woman, realistic, warm smile",
180
  lines=2
181
  )
182
 
 
205
  outputs=output
206
  )
207
 
208
+ # Ativa suporte a yield
209
  demo.queue()
210
+
211
  if __name__ == "__main__":
212
  demo.launch()