# Trabre / app.py
# Last commit: "Update app.py" by Allex21 (81d438f, verified)
# app.py
import os
import gradio as gr
from preprocess import process_dataset
import subprocess
import zipfile
import shutil
import time
def _upload_path(uploaded):
    """Return the filesystem path of one Gradio upload, or None if unknown.

    Depending on the Gradio version, ``gr.File`` values may arrive either as
    plain path strings or as tempfile-like objects exposing ``.name``; both
    forms are accepted here.
    """
    if isinstance(uploaded, (str, os.PathLike)):
        return os.fspath(uploaded)
    return getattr(uploaded, "name", None)


def _clear_directory(path):
    """Remove every file, symlink and sub-directory inside *path*.

    Cleanup is best effort: a failing entry does not stop the others.

    Returns:
        list[str]: One warning message per entry that could not be removed.
    """
    warnings = []
    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        try:
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except OSError as exc:
            warnings.append(f"⚠️ Erro ao limpar: {exc}")
    return warnings


def train_lora_interface(
    dataset_input, input_type, model_name, lora_rank, learning_rate,
    num_epochs, hub_token, concept_name, description
):
    """Prepare a captioned dataset from the upload and run ``train_lora.py``.

    This is a generator: every ``yield`` is a status message (Portuguese,
    user-facing) meant to be streamed into the UI.

    Args:
        dataset_input: Upload(s) from ``gr.File`` -- a ZIP archive or image
            files, given as path strings or objects exposing ``.name``
            (a single value or a list).
        input_type: ``"Upload de ZIP"`` extracts the first upload as a ZIP;
            any other value copies the uploaded files as images.
        model_name: Base model id, forwarded as ``--model_name``.
        lora_rank: LoRA rank, forwarded as an integer.
        learning_rate: Forwarded as ``--learning_rate``.
        num_epochs: Number of epochs, forwarded as an integer.
        hub_token: Optional Hugging Face token. When non-blank it is given to
            the child process as ``HF_TOKEN`` and ``--push_to_hub`` is added.
        concept_name: Trigger word; spaces are replaced with underscores.
        description: Base caption; the concept name is appended to it.

    Yields:
        str: Progress, warning, error, or final success messages.
    """
    import sys  # for sys.executable

    concept_name = (concept_name or "").strip()
    description = (description or "").strip()
    if not dataset_input:
        yield "❌ Por favor, envie um ZIP ou selecione imagens."
        return
    if not concept_name:
        yield "❌ Por favor, defina um nome para o conceito (ex: brenda)."
        return
    if not description:
        yield "❌ Por favor, adicione uma descrição base."
        return
    concept_name = concept_name.replace(" ", "_")
    full_description = f"{description}, {concept_name}"
    yield f"🏷️ Treinando: '{concept_name}'"

    dataset_dir = "processed_data"
    os.makedirs(dataset_dir, exist_ok=True)
    # Start from an empty folder so files from a previous run are not reused.
    yield from _clear_directory(dataset_dir)

    # --- Step 1: bring the uploaded data into dataset_dir ---
    if input_type == "Upload de ZIP":
        uploaded = dataset_input[0] if isinstance(dataset_input, list) else dataset_input
        zip_path = _upload_path(uploaded)
        if zip_path is None or not zipfile.is_zipfile(zip_path):
            yield "❌ Arquivo não é um ZIP válido."
            return
        yield "📦 Descompactando..."
        with zipfile.ZipFile(zip_path, "r") as archive:
            # NOTE: extractall() keeps the archive's folder layout, but the
            # caption step below only scans the top level of dataset_dir, so
            # images inside sub-folders of the ZIP are not picked up.
            archive.extractall(dataset_dir)
            member_count = len(archive.namelist())
        yield f"✅ ZIP extraído! {member_count} arquivos."
    else:
        uploads = dataset_input if isinstance(dataset_input, list) else [dataset_input]
        yield f"🖼️ Copiando {len(uploads)} imagens..."
        copied = 0
        for uploaded in uploads:
            src = _upload_path(uploaded)
            if src is None:
                continue
            shutil.copy(src, os.path.join(dataset_dir, os.path.basename(src)))
            copied += 1
        # Report what was actually copied, not just what was received.
        yield f"✅ {copied} imagens copiadas."

    # --- Step 2: write a caption file next to every top-level image ---
    exts = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
    images = [f for f in os.listdir(dataset_dir) if f.lower().endswith(exts)]
    if not images:
        yield "❌ Nenhuma imagem encontrada!"
        return
    yield f"📝 Aplicando legenda: '{full_description}'"
    for img in images:
        caption_path = os.path.join(dataset_dir, os.path.splitext(img)[0] + ".txt")
        # Keep captions that already exist (e.g. shipped inside the ZIP).
        if not os.path.exists(caption_path):
            with open(caption_path, "w", encoding="utf-8") as f:
                f.write(full_description)
    yield "🔍 Legendas prontas!"

    # --- Step 3: run the training script as a child process ---
    output_dir = "lora-output"
    os.makedirs(output_dir, exist_ok=True)
    cmd = [
        # Use the interpreter running this app, not whatever "python" is on PATH.
        sys.executable or "python", "train_lora.py",
        "--dataset_dir", dataset_dir,
        "--model_name", model_name,
        # Sliders may deliver floats such as 4.0; these options are counts.
        "--lora_rank", str(int(lora_rank)),
        "--learning_rate", str(learning_rate),
        "--num_epochs", str(int(num_epochs)),
        "--batch_size", "1",
        "--output_dir", output_dir,
    ]
    hub_token = (hub_token or "").strip()
    # Hand the token only to the child process. Writing it into os.environ
    # would leave it set in this server process for every later request.
    child_env = dict(os.environ)
    if hub_token:
        child_env["HF_TOKEN"] = hub_token
        cmd += ["--push_to_hub", "--hub_model_id", f"{concept_name}-lora"]

    yield "🔥 Iniciando treinamento..."
    process = None
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",  # undecodable output must not abort the log loop
            bufsize=1,
            env=child_env,
        )
        log_lines = []
        for line in process.stdout:
            log_lines.append(line)
            lowered = line.lower()
            # Only surface progress-looking lines; full output kept for errors.
            if "loss" in lowered or "epoch" in lowered:
                yield f"📊 {line.strip()}"
        process.wait()
        if process.returncode == 0:
            hub_note = "🔹 Publicado no Hub!" if hub_token else ""
            yield (
                "\n🎉 SUCESSO!\n"
                f"🔹 Use no prompt: `photo of {concept_name} in the forest`\n"
                f"🔹 Modelo salvo em: `{output_dir}`\n"
                f"{hub_note}\n"
            )
        else:
            log_tail = "".join(log_lines)[-1000:]
            yield f"❌ Falha no treinamento. Código: {process.returncode}\nLogs:\n{log_tail}"
    except Exception as e:  # UI boundary: report instead of crashing the handler
        yield f"💥 Erro: {e}"
    finally:
        # If the generator is closed (client cancelled) or an error occurred
        # while training is still running, do not leave an orphaned process.
        if process is not None:
            if process.poll() is None:
                process.kill()
                process.wait()
            if process.stdout is not None:
                process.stdout.close()
# --- Interface ---
# Gradio UI: collects the dataset upload, the concept identity and the
# training settings, and wires the button to train_lora_interface, whose
# yielded status messages are shown in the "📦 Logs" textbox.
# NOTE(review): this copy of the file lost its indentation; the nesting
# below (in particular which widgets sit inside the "Configurações" row) is
# a reconstruction -- confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
    gr.Markdown("Treine personagens, estilos ou objetos personalizados.")
    # Input mode; train_lora_interface compares this value with
    # "Upload de ZIP" to choose between extracting a ZIP and copying images.
    with gr.Row():
        input_type = gr.Radio(
            ["Upload de ZIP", "Selecionar várias imagens"],
            label="Tipo de Entrada",
            value="Upload de ZIP"
        )
    # One widget accepts both a ZIP and loose images; with
    # file_count="multiple" the handler may receive a list of uploads.
    with gr.Row():
        dataset_input = gr.File(
            label="📤 Envie seu ZIP ou imagens",
            file_types=[".zip", ".jpg", ".jpeg", ".png", ".bmp", ".webp"],
            file_count="multiple"
        )
    gr.Markdown("### 🔖 Identidade do Personagem")
    # Trigger word; train_lora_interface replaces spaces with underscores
    # and appends it to every generated caption.
    with gr.Row():
        concept_name = gr.Textbox(
            label="Nome do Conceito (ex: brenda)",
            placeholder="Ex: brenda, cyborg_x",
            value=""
        )
    # Base caption text written to each image's .txt caption file.
    with gr.Row():
        description = gr.Textbox(
            label="Descrição Base (ex: woman, curly hair)",
            placeholder="Ex: young black woman, realistic style",
            lines=2
        )
    gr.Markdown("### ⚙️ Configurações")
    with gr.Row():
        # Only one base model is offered; it is passed as --model_name.
        model_name = gr.Dropdown(
            ["runwayml/stable-diffusion-v1-5"],
            value="runwayml/stable-diffusion-v1-5",
            label="Modelo Base"
        )
        lora_rank = gr.Slider(4, 64, value=4, step=4, label="LoRA Rank")
        learning_rate = gr.Number(value=1e-4, label="Taxa de Aprendizado")
        num_epochs = gr.Slider(1, 30, value=10, step=1, label="Épocas")
    # Optional: when filled in, training also pushes the result to the Hub.
    hub_token = gr.Textbox(label="🔐 Token do HF (opcional)", type="password")
    btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
    output = gr.Textbox(label="📦 Logs", lines=12)
    btn.click(
        train_lora_interface,
        # Order must match train_lora_interface's positional parameters.
        inputs=[
            dataset_input, input_type, model_name, lora_rank,
            learning_rate, num_epochs, hub_token, concept_name, description
        ],
        outputs=output
    )
# Enable Gradio's request queue so the generator handler's yields can be
# streamed to the browser.
demo.queue()
if __name__ == "__main__":
    demo.launch()