Trabre / app.py
Allex21's picture
Update app.py
bba7e77 verified
raw
history blame
7.94 kB
# app.py
# ⚠️ NÃO REMOVA: Força instalação de pacotes críticos antes de qualquer import
import subprocess
import sys
import os
import shutil
def install_packages():
print("🔧 Forçando reinstalação de diffusers e huggingface_hub...")
try:
subprocess.check_call([
sys.executable, "-m", "pip", "install",
"--force-reinstall",
"diffusers>=0.26.0",
"huggingface-hub>=0.20.0",
"accelerate",
"peft",
"torch==2.3.0",
"transformers==4.40.0"
])
print("✅ Pacotes essenciais instalados!")
except Exception as e:
print(f"❌ Falha na instalação forçada: {e}")
# Executa apenas uma vez por sessão
if not os.path.exists("/tmp/packages_installed"):
install_packages()
with open("/tmp/packages_installed", "w") as f:
f.write("ok")
# Agora sim, imports seguros
import gradio as gr
from preprocess import process_dataset
import subprocess
import zipfile
import time
def train_lora_interface(
dataset_input, input_type, model_name, lora_rank, learning_rate,
num_epochs, hub_token, concept_name, description
):
if not dataset_input:
yield "❌ Por favor, envie um ZIP ou selecione imagens."
return
if not concept_name.strip():
yield "❌ Por favor, defina um nome para o conceito (ex: brenda)."
return
if not description.strip():
yield "❌ Por favor, adicione uma descrição base (ex: mulher, 30 anos, cabelo cacheado)."
return
concept_name = concept_name.strip().replace(" ", "_")
full_description = f"{description.strip()}, {concept_name}"
yield f"🏷️ Treinando conceito: '{concept_name}' → Prompt: [photo of {concept_name}]"
# Pasta de trabalho
dataset_dir = "processed_data"
os.makedirs(dataset_dir, exist_ok=True)
# Limpa pasta anterior
for item in os.listdir(dataset_dir):
item_path = os.path.join(dataset_dir, item)
try:
if os.path.isfile(item_path) or os.path.islink(item_path):
os.unlink(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
except Exception as e:
yield f"⚠️ Erro ao limpar: {e}"
# --- ETAPA 1: Processar entrada ---
if input_type == "Upload de ZIP":
zip_file = dataset_input[0] if isinstance(dataset_input, list) else dataset_input
if not zipfile.is_zipfile(zip_file):
yield "❌ Arquivo não é um ZIP válido."
return
yield "📦 Descompactando ZIP..."
with zipfile.ZipFile(zip_file, 'r') as z:
z.extractall(dataset_dir)
yield f"✅ ZIP extraído! {len(z.namelist())} arquivos."
else: # Múltiplas imagens
image_files = dataset_input if isinstance(dataset_input, list) else [dataset_input]
yield f"🖼️ Recebidas {len(image_files)} imagens. Copiando..."
for uploaded_file in image_files:
if hasattr(uploaded_file, 'name'):
src_path = uploaded_file.name
filename = os.path.basename(src_path)
dest_path = os.path.join(dataset_dir, filename)
shutil.copy(src_path, dest_path) # Usa copy, não rename
else:
yield f"⚠️ Arquivo inválido: {uploaded_file}"
yield f"✅ {len(image_files)} imagens copiadas."
# --- ETAPA 2: Verifica imagens e gera legendas ---
image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')
image_files = [f for f in os.listdir(dataset_dir) if f.lower().endswith(image_extensions)]
if len(image_files) == 0:
yield "❌ Nenhuma imagem encontrada. Envie arquivos válidos."
return
yield f"📝 Aplicando legenda base: '{full_description}'"
for img_name in image_files:
txt_path = os.path.join(dataset_dir, os.path.splitext(img_name)[0] + ".txt")
if not os.path.exists(txt_path):
with open(txt_path, "w", encoding="utf-8") as f:
f.write(full_description)
yield "🔍 Legendas aplicadas com sucesso!"
# --- ETAPA 3: Treinamento ---
output_dir = "lora-output"
os.makedirs(output_dir, exist_ok=True)
cmd = [
"python", "train_lora.py",
"--dataset_dir", dataset_dir,
"--model_name", model_name,
"--lora_rank", str(lora_rank),
"--learning_rate", str(learning_rate),
"--num_epochs", str(num_epochs),
"--batch_size", "1",
"--output_dir", output_dir
]
if hub_token:
os.environ["HF_TOKEN"] = hub_token
cmd.append("--push_to_hub")
cmd.append("--hub_model_id")
cmd.append(f"{concept_name}-lora")
yield "🔥 Iniciando treinamento LoRA... Isso pode levar alguns minutos."
try:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
bufsize=1,
encoding='utf-8'
)
log_output = ""
for line in process.stdout:
log_output += line
if "loss" in line.lower() or "epoch" in line.lower():
yield f"📊 {line.strip()}"
process.wait()
if process.returncode == 0:
yield f"""
🎉 TREINAMENTO CONCLUÍDO!
🔹 Use no prompt: `photo of {concept_name} in the forest`
🔹 Modelo salvo em: `{output_dir}`
{'🔹 Publicado no Hub!' if hub_token else ''}
"""
else:
yield f"❌ Treinamento falhou. Código: {process.returncode}\nLogs:\n{log_output[-1000:]}"
except Exception as e:
yield f"💥 Erro ao executar treinamento: {str(e)}"
# --- Interface Gradio ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
gr.Markdown("Treine seu próprio modelo com nome, descrição e imagens!")
with gr.Row():
input_type = gr.Radio(
["Upload de ZIP", "Selecionar várias imagens"],
label="Tipo de Entrada",
value="Upload de ZIP"
)
with gr.Row():
dataset_input = gr.File(
label="📤 Envie seu ZIP ou imagens",
file_types=[".zip", ".jpg", ".jpeg", ".png", ".bmp", ".webp"],
file_count="multiple"
)
gr.Markdown("### 🔖 Identidade do Personagem/Conceito")
with gr.Row():
concept_name = gr.Textbox(
label="Nome do Conceito (ex: brenda)",
placeholder="Ex: brenda, cyborg_x, estilo_pintura",
value=""
)
with gr.Row():
description = gr.Textbox(
label="Descrição Base (ex: woman, curly hair, realistic)",
placeholder="Ex: young black woman, warm smile, detailed face",
lines=2
)
gr.Markdown("### ⚙️ Configurações do Treinamento")
with gr.Row():
model_name = gr.Dropdown(
["runwayml/stable-diffusion-v1-5"],
value="runwayml/stable-diffusion-v1-5",
label="Modelo Base"
)
lora_rank = gr.Slider(4, 64, value=4, step=4, label="LoRA Rank")
learning_rate = gr.Number(value=1e-4, label="Taxa de Aprendizado")
num_epochs = gr.Slider(1, 30, value=10, step=1, label="Épocas")
hub_token = gr.Textbox(label="🔐 Token do Hugging Face (opcional)", type="password")
btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
output = gr.Textbox(label="📦 Logs e Resultado", lines=12)
btn.click(
train_lora_interface,
inputs=[
dataset_input, input_type, model_name, lora_rank,
learning_rate, num_epochs, hub_token, concept_name, description
],
outputs=output
)
# Ativa suporte a yield
demo.queue()
if __name__ == "__main__":
demo.launch()