# app.py
"""Gradio front-end that stages an image dataset and launches ``train_lora.py``."""
import os
import shutil
import subprocess
import sys
import time
import zipfile
from collections import deque

import gradio as gr

from preprocess import process_dataset

# NOTE(review): `time` and `process_dataset` are not referenced in this module;
# they are kept in case something relies on their import side effects.

DATASET_DIR = "processed_data"
OUTPUT_DIR = "lora-output"
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
# Bounded log retention: only the tail is shown when training fails.
LOG_TAIL_LINES = 200
LOG_TAIL_CHARS = 1000


def _upload_path(uploaded):
    """Return the filesystem path of one Gradio upload, or None if it has none.

    Depending on the Gradio version / ``type=`` setting, ``gr.File`` hands back
    either plain path strings or tempfile wrappers exposing ``.name``; both are
    accepted here.
    """
    if isinstance(uploaded, (str, os.PathLike)):
        return os.fspath(uploaded)
    return getattr(uploaded, "name", None)


def _as_list(value):
    """Wrap a single upload in a list; pass lists through unchanged."""
    return value if isinstance(value, list) else [value]


def _clear_directory(path):
    """Delete every entry inside *path* (not *path* itself).

    Returns the ``OSError``s encountered so the caller can report them and
    carry on, instead of aborting the whole run.
    """
    errors = []
    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        try:
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except OSError as exc:
            errors.append(exc)
    return errors


def _list_images(directory):
    """Return names of top-level files in *directory* with an image extension."""
    # NOTE(review): only top-level entries are scanned; a ZIP whose images sit
    # inside a sub-folder will report "no images found".
    return [
        name for name in os.listdir(directory)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    ]


def _write_missing_captions(directory, images, caption):
    """Write ``<stem>.txt`` containing *caption* next to each image lacking one.

    Caption files that already exist (e.g. shipped inside the ZIP) are kept.
    """
    for image in images:
        caption_path = os.path.join(directory, os.path.splitext(image)[0] + ".txt")
        if not os.path.exists(caption_path):
            with open(caption_path, "w", encoding="utf-8") as fh:
                fh.write(caption)


def _build_train_command(
    dataset_dir, model_name, lora_rank, learning_rate, num_epochs, output_dir,
    hub_model_id=None,
):
    """Build the argv list for ``train_lora.py`` (batch size fixed at 1)."""
    cmd = [
        # Run the trainer with the same interpreter as this app, not whatever
        # "python" happens to be first on PATH.
        sys.executable or "python", "train_lora.py",
        "--dataset_dir", dataset_dir,
        "--model_name", model_name,
        "--lora_rank", str(lora_rank),
        "--learning_rate", str(learning_rate),
        "--num_epochs", str(num_epochs),
        "--batch_size", "1",
        "--output_dir", output_dir,
    ]
    if hub_model_id:
        cmd += ["--push_to_hub", "--hub_model_id", hub_model_id]
    return cmd


def _run_training(cmd, env, concept_name, output_dir, pushed_to_hub):
    """Run *cmd*, yielding loss/epoch lines and a final success/failure message."""
    process = None
    tail = deque(maxlen=LOG_TAIL_LINES)
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",  # a stray non-UTF-8 byte must not kill streaming
            bufsize=1,
            env=env,
        )
        for line in process.stdout:
            tail.append(line)
            lowered = line.lower()
            if "loss" in lowered or "epoch" in lowered:
                yield f"📊 {line.strip()}"
        process.wait()

        if process.returncode == 0:
            hub_note = "🔹 Publicado no Hub!" if pushed_to_hub else ""
            yield (
                "\n🎉 SUCESSO!\n"
                f"🔹 Use no prompt: `photo of {concept_name} in the forest`\n"
                f"🔹 Modelo salvo em: `{output_dir}`\n"
                f"{hub_note}\n"
            )
        else:
            log_tail = "".join(tail)[-LOG_TAIL_CHARS:]
            yield (
                f"❌ Falha no treinamento.\nCódigo: {process.returncode}"
                f"\nLogs:\n{log_tail}"
            )
    except Exception as exc:  # UI boundary: report launch/streaming failures
        yield f"💥 Erro: {exc}"
    finally:
        # If streaming failed or the UI closed the generator (GeneratorExit),
        # don't leave an orphaned training process behind.
        if process is not None:
            if process.poll() is None:
                process.kill()
                process.wait()
            if process.stdout is not None:
                process.stdout.close()


def train_lora_interface(
    dataset_input, input_type, model_name, lora_rank, learning_rate,
    num_epochs, hub_token, concept_name, description,
):
    """Stage the uploaded dataset, caption it and run ``train_lora.py``.

    Generator: every ``yield`` is a user-facing status message that Gradio
    streams into the log textbox.

    Args:
        dataset_input: Upload(s) from ``gr.File`` — a ZIP or image files.
        input_type: "Upload de ZIP" extracts the first upload as a ZIP; any
            other value copies the uploads into the dataset folder as-is.
        model_name: Base model identifier forwarded to the training script.
        lora_rank: LoRA rank; coerced to ``int``.
        learning_rate: Learning rate; must be a positive number.
        num_epochs: Epoch count; coerced to ``int``.
        hub_token: Optional Hugging Face token. When non-blank it is given to
            the training subprocess as ``HF_TOKEN`` (never written into this
            process's environment) and ``--push_to_hub`` is requested.
        concept_name: Trigger word; spaces are replaced by underscores.
        description: Base caption; the concept name is appended to it.
    """
    if not dataset_input:
        yield "❌ Por favor, envie um ZIP ou selecione imagens."
        return
    concept_name = (concept_name or "").strip()
    if not concept_name:
        yield "❌ Por favor, defina um nome para o conceito (ex: brenda)."
        return
    description = (description or "").strip()
    if not description:
        yield "❌ Por favor, adicione uma descrição base."
        return
    if learning_rate is None or learning_rate <= 0:
        yield "❌ Taxa de aprendizado inválida."
        return
    hub_token = (hub_token or "").strip()

    concept_name = concept_name.replace(" ", "_")
    full_description = f"{description}, {concept_name}"
    yield f"🏷️ Treinando: '{concept_name}'"

    os.makedirs(DATASET_DIR, exist_ok=True)
    # Start from an empty folder so files from a previous run don't leak in.
    for exc in _clear_directory(DATASET_DIR):
        yield f"⚠️ Erro ao limpar: {exc}"

    # --- Step 1: stage the input files ---
    if input_type == "Upload de ZIP":
        zip_path = _upload_path(_as_list(dataset_input)[0])
        if not zip_path or not zipfile.is_zipfile(zip_path):
            yield "❌ Arquivo não é um ZIP válido."
            return
        yield "📦 Descompactando..."
        try:
            with zipfile.ZipFile(zip_path, "r") as archive:
                # extractall() strips absolute paths and ".." components from
                # member names, keeping the output inside DATASET_DIR.
                archive.extractall(DATASET_DIR)
                member_count = len(archive.namelist())
        except (zipfile.BadZipFile, OSError) as exc:
            yield f"💥 Erro: {exc}"
            return
        yield f"✅ ZIP extraído! {member_count} arquivos."
    else:
        uploads = _as_list(dataset_input)
        yield f"🖼️ Copiando {len(uploads)} imagens..."
        copied = 0
        for uploaded in uploads:
            src = _upload_path(uploaded)
            if not src:
                continue
            shutil.copy(src, os.path.join(DATASET_DIR, os.path.basename(src)))
            copied += 1
        yield f"✅ {copied} imagens copiadas."

    # --- Step 2: captions ---
    images = _list_images(DATASET_DIR)
    if not images:
        yield "❌ Nenhuma imagem encontrada!"
        return
    yield f"📝 Aplicando legenda: '{full_description}'"
    _write_missing_captions(DATASET_DIR, images, full_description)
    yield "🔍 Legendas prontas!"

    # --- Step 3: training ---
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    cmd = _build_train_command(
        DATASET_DIR, model_name, int(lora_rank), learning_rate, int(num_epochs),
        OUTPUT_DIR,
        hub_model_id=f"{concept_name}-lora" if hub_token else None,
    )
    # Give the token to the child only; writing it into os.environ would
    # leave it set for every later request served by this process.
    env = os.environ.copy()
    if hub_token:
        env["HF_TOKEN"] = hub_token

    yield "🔥 Iniciando treinamento..."
    yield from _run_training(cmd, env, concept_name, OUTPUT_DIR, bool(hub_token))


# --- Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
    gr.Markdown("Treine personagens, estilos ou objetos personalizados.")

    with gr.Row():
        input_type = gr.Radio(
            ["Upload de ZIP", "Selecionar várias imagens"],
            label="Tipo de Entrada",
            value="Upload de ZIP",
        )
    with gr.Row():
        dataset_input = gr.File(
            label="📤 Envie seu ZIP ou imagens",
            file_types=[".zip", ".jpg", ".jpeg", ".png", ".bmp", ".webp"],
            file_count="multiple",
        )

    gr.Markdown("### 🔖 Identidade do Personagem")
    with gr.Row():
        concept_name = gr.Textbox(
            label="Nome do Conceito (ex: brenda)",
            placeholder="Ex: brenda, cyborg_x",
            value="",
        )
    with gr.Row():
        description = gr.Textbox(
            label="Descrição Base (ex: woman, curly hair)",
            placeholder="Ex: young black woman, realistic style",
            lines=2,
        )

    gr.Markdown("### ⚙️ Configurações")
    with gr.Row():
        model_name = gr.Dropdown(
            ["runwayml/stable-diffusion-v1-5"],
            value="runwayml/stable-diffusion-v1-5",
            label="Modelo Base",
        )
        lora_rank = gr.Slider(4, 64, value=4, step=4, label="LoRA Rank")
        learning_rate = gr.Number(value=1e-4, label="Taxa de Aprendizado")
        num_epochs = gr.Slider(1, 30, value=10, step=1, label="Épocas")

    hub_token = gr.Textbox(label="🔐 Token do HF (opcional)", type="password")
    btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
    output = gr.Textbox(label="📦 Logs", lines=12)

    btn.click(
        train_lora_interface,
        inputs=[
            dataset_input, input_type, model_name, lora_rank, learning_rate,
            num_epochs, hub_token, concept_name, description,
        ],
        outputs=output,
    )

# Queueing is what lets Gradio stream the generator's yields to the UI.
demo.queue()

if __name__ == "__main__":
    demo.launch()