File size: 6,371 Bytes
9275790 4cafbc8 09dea70 81d438f 4cafbc8 594c14a 09dea70 1ab5ebf 594c14a 81d438f 594c14a 09dea70 1ab5ebf 594c14a 09dea70 81d438f 594c14a 09dea70 594c14a 09dea70 1ab5ebf bba7e77 09dea70 1ab5ebf 09dea70 bba7e77 81d438f 1ab5ebf 09dea70 1ab5ebf 09dea70 81d438f 1ab5ebf 81d438f 1ab5ebf 81d438f 1ab5ebf bba7e77 1ab5ebf 81d438f 09dea70 81d438f 594c14a 81d438f 594c14a 09dea70 81d438f 09dea70 1ab5ebf 09dea70 81d438f 09dea70 81d438f 1ab5ebf 09dea70 4cafbc8 9275790 4cafbc8 9275790 4cafbc8 9275790 4cafbc8 9275790 09dea70 9275790 4cafbc8 9275790 594c14a 81d438f 594c14a 1ab5ebf 594c14a 9275790 81d438f 09dea70 9275790 81d438f 9275790 81d438f 9275790 81d438f 9275790 09dea70 1ab5ebf 594c14a 81d438f 594c14a 1ab5ebf 81d438f 594c14a 81d438f 594c14a 09dea70 9275790 81d438f 9275790 4cafbc8 9275790 4cafbc8 9275790 4cafbc8 9275790 81d438f 9275790 81d438f 9275790 594c14a 4cafbc8 9275790 4cafbc8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
# app.py
import os
import shutil
import subprocess
import sys
import time
import zipfile

import gradio as gr

from preprocess import process_dataset
def _upload_path(upload):
    """Return the filesystem path behind one uploaded item, or ``None``.

    Accepts a plain path (``str`` / ``os.PathLike``) as well as an object
    that exposes its path through a string ``.name`` attribute (e.g. a
    temp-file wrapper), so both upload representations are handled.
    """
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    name = getattr(upload, "name", None)
    return name if isinstance(name, str) else None


def train_lora_interface(
    dataset_input, input_type, model_name, lora_rank, learning_rate,
    num_epochs, hub_token, concept_name, description
):
    """Prepare an uploaded dataset and run LoRA training, streaming status text.

    This is a generator: every yielded string is one status/log message.

    Steps:
      1. Validate the inputs (dataset, concept name, description).
      2. Empty the ``processed_data`` directory.
      3. Extract the uploaded ZIP, or copy the uploaded images, into it.
      4. Write ``"<description>, <concept>"`` into a ``.txt`` caption next to
         every top-level image that does not already have one.
      5. Run ``train_lora.py`` as a subprocess and relay its output lines
         that mention "loss" or "epoch", then report success or failure.

    Args:
        dataset_input: One uploaded item or a list of them; each item is a
            path or an object whose ``.name`` is a path.
        input_type: ``"Upload de ZIP"`` selects ZIP extraction; any other
            value treats ``dataset_input`` as individual image files.
        model_name: Value forwarded as ``--model_name``.
        lora_rank: Value forwarded as ``--lora_rank``.
        learning_rate: Value forwarded as ``--learning_rate``.
        num_epochs: Value forwarded as ``--num_epochs``.
        hub_token: Optional token; when truthy it is exposed to the training
            subprocess as ``HF_TOKEN`` and ``--push_to_hub`` is requested.
        concept_name: Trigger word; spaces are replaced by underscores.
        description: Base caption text placed before the concept name.

    Yields:
        str: Progress, warning, error or final result messages.
    """
    if not dataset_input:
        yield "❌ Por favor, envie um ZIP ou selecione imagens."
        return
    if not (concept_name or "").strip():
        yield "❌ Por favor, defina um nome para o conceito (ex: brenda)."
        return
    if not (description or "").strip():
        yield "❌ Por favor, adicione uma descrição base."
        return

    concept_name = concept_name.strip().replace(" ", "_")
    full_description = f"{description.strip()}, {concept_name}"
    yield f"🏷️ Treinando: '{concept_name}'"

    dataset_dir = "processed_data"
    os.makedirs(dataset_dir, exist_ok=True)

    # Remove leftovers from a previous run; a failure on one entry is
    # reported but does not abort the run.
    for item in os.listdir(dataset_dir):
        item_path = os.path.join(dataset_dir, item)
        try:
            if os.path.isfile(item_path) or os.path.islink(item_path):
                os.unlink(item_path)
            elif os.path.isdir(item_path):
                shutil.rmtree(item_path)
        except Exception as e:
            yield f"⚠️ Erro ao limpar: {e}"

    # --- Step 1: ingest the upload ---
    if input_type == "Upload de ZIP":
        first = dataset_input[0] if isinstance(dataset_input, list) else dataset_input
        # Fall back to the raw object so file-like uploads without a usable
        # path are still handed to zipfile unchanged.
        zip_source = _upload_path(first) or first
        if not zipfile.is_zipfile(zip_source):
            yield "❌ Arquivo não é um ZIP válido."
            return
        yield "📦 Descompactando..."
        with zipfile.ZipFile(zip_source, 'r') as z:
            z.extractall(dataset_dir)
            extracted = len(z.namelist())
        # Yield only after the archive is closed so the handle is not held
        # open while the generator is suspended.
        yield f"✅ ZIP extraído! {extracted} arquivos."
    else:
        uploads = dataset_input if isinstance(dataset_input, list) else [dataset_input]
        yield f"🖼️ Copiando {len(uploads)} imagens..."
        copied = 0
        for upload in uploads:
            src = _upload_path(upload)
            if src is None:
                continue
            shutil.copy(src, os.path.join(dataset_dir, os.path.basename(src)))
            copied += 1
        # Report what was actually copied, not what was submitted.
        yield f"✅ {copied} imagens copiadas."

    # --- Step 2: write captions ---
    # NOTE(review): only the top level of dataset_dir is scanned, so images
    # inside sub-folders of an extracted ZIP are not picked up here.
    exts = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')
    images = [f for f in os.listdir(dataset_dir) if f.lower().endswith(exts)]
    if not images:
        yield "❌ Nenhuma imagem encontrada!"
        return

    yield f"📝 Aplicando legenda: '{full_description}'"
    for img in images:
        txt = os.path.join(dataset_dir, os.path.splitext(img)[0] + ".txt")
        # Captions that came with the dataset are kept as they are.
        if not os.path.exists(txt):
            with open(txt, "w", encoding="utf-8") as f:
                f.write(full_description)
    yield "🔍 Legendas prontas!"

    # --- Step 3: training ---
    output_dir = "lora-output"
    os.makedirs(output_dir, exist_ok=True)

    cmd = [
        # Use the interpreter running this app rather than whatever
        # "python" happens to resolve to on PATH.
        sys.executable or "python", "train_lora.py",
        "--dataset_dir", dataset_dir,
        "--model_name", model_name,
        "--lora_rank", str(lora_rank),
        "--learning_rate", str(learning_rate),
        "--num_epochs", str(num_epochs),
        "--batch_size", "1",
        "--output_dir", output_dir
    ]

    # The token goes only into the child's environment; writing it to
    # os.environ would leave it set for every later run of this process.
    env = os.environ.copy()
    if hub_token:
        env["HF_TOKEN"] = hub_token
        cmd += ["--push_to_hub", "--hub_model_id", f"{concept_name}-lora"]

    yield "🔥 Iniciando treinamento..."
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            bufsize=1,
            encoding='utf-8',
            errors='replace',  # a stray undecodable byte must not abort the log stream
            env=env,
        )
        log_tail = ""
        try:
            for line in process.stdout:
                # Only the last 1000 characters are ever reported.
                log_tail = (log_tail + line)[-1000:]
                lowered = line.lower()
                if "loss" in lowered or "epoch" in lowered:
                    yield f"📊 {line.strip()}"
            process.wait()
        finally:
            # Reached early when streaming is interrupted (error or the
            # generator being closed): do not leave the child running
            # against a pipe nobody reads.
            if process.poll() is None:
                process.kill()
                process.wait()
            process.stdout.close()

        if process.returncode == 0:
            yield f"""
🎉 SUCESSO!
🔹 Use no prompt: `photo of {concept_name} in the forest`
🔹 Modelo salvo em: `{output_dir}`
{'🔹 Publicado no Hub!' if hub_token else ''}
"""
        else:
            yield f"❌ Falha no treinamento. Código: {process.returncode}\nLogs:\n{log_tail}"
    except Exception as e:
        yield f"💥 Erro: {str(e)}"
# --- Interface ---
# Gradio Blocks UI: collects the dataset and the training settings, and wires
# the button to train_lora_interface, whose yielded messages go to the log box.
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation was lost — confirm which components belong inside each gr.Row().
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
    gr.Markdown("Treine personagens, estilos ou objetos personalizados.")

    # Input mode; train_lora_interface compares this value against
    # "Upload de ZIP" to choose between ZIP extraction and image copying.
    with gr.Row():
        input_type = gr.Radio(
            ["Upload de ZIP", "Selecionar várias imagens"],
            label="Tipo de Entrada",
            value="Upload de ZIP"
        )

    # A single uploader serves both modes (one ZIP or several images).
    with gr.Row():
        dataset_input = gr.File(
            label="📤 Envie seu ZIP ou imagens",
            file_types=[".zip", ".jpg", ".jpeg", ".png", ".bmp", ".webp"],
            file_count="multiple"
        )

    gr.Markdown("### 🔖 Identidade do Personagem")

    # Trigger word appended to every caption.
    with gr.Row():
        concept_name = gr.Textbox(
            label="Nome do Conceito (ex: brenda)",
            placeholder="Ex: brenda, cyborg_x",
            value=""
        )

    # Base caption text placed before the concept name.
    with gr.Row():
        description = gr.Textbox(
            label="Descrição Base (ex: woman, curly hair)",
            placeholder="Ex: young black woman, realistic style",
            lines=2
        )

    gr.Markdown("### ⚙️ Configurações")

    # Training hyperparameters, forwarded to train_lora.py as CLI arguments.
    with gr.Row():
        model_name = gr.Dropdown(
            ["runwayml/stable-diffusion-v1-5"],
            value="runwayml/stable-diffusion-v1-5",
            label="Modelo Base"
        )
        lora_rank = gr.Slider(4, 64, value=4, step=4, label="LoRA Rank")
        learning_rate = gr.Number(value=1e-4, label="Taxa de Aprendizado")
        num_epochs = gr.Slider(1, 30, value=10, step=1, label="Épocas")

    # Optional token; when provided the handler also requests a Hub push.
    hub_token = gr.Textbox(label="🔐 Token do HF (opcional)", type="password")

    btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
    output = gr.Textbox(label="📦 Logs", lines=12)

    # The order of `inputs` must match the parameter order of
    # train_lora_interface.
    btn.click(
        train_lora_interface,
        inputs=[
            dataset_input, input_type, model_name, lora_rank,
            learning_rate, num_epochs, hub_token, concept_name, description
        ],
        outputs=output
    )

# Enable request queuing — presumably needed so the generator handler can
# stream its intermediate messages; confirm against the Gradio version in use.
demo.queue()
if __name__ == "__main__":
    # Launch the Gradio app only when this file is executed as a script.
    demo.launch()