|
|
|
|
|
import os
import shutil
import subprocess
import sys
import time
import zipfile

import gradio as gr

from preprocess import process_dataset
|
|
|
|
|
def train_lora_interface(
    dataset_input, input_type, model_name, lora_rank, learning_rate,
    num_epochs, hub_token, concept_name, description
):
    """Run the end-to-end LoRA fine-tuning pipeline for the Gradio UI.

    This is a generator: every ``yield`` is a user-facing status string
    (Portuguese, emoji-tagged) that Gradio streams into the logs textbox.

    Args:
        dataset_input: Upload(s) from ``gr.File`` — a single item or a list.
            Items may be tempfile objects exposing ``.name`` or plain path
            strings, depending on the Gradio version; both are handled.
        input_type: "Upload de ZIP" to extract an archive, anything else to
            copy loose image files.
        model_name: Base diffusion model id forwarded to ``train_lora.py``.
        lora_rank: LoRA rank, forwarded as a CLI string.
        learning_rate: Learning rate, forwarded as a CLI string.
        num_epochs: Epoch count, forwarded as a CLI string.
        hub_token: Optional Hugging Face token. When truthy it is exported
            as ``HF_TOKEN`` and the result is pushed as ``{concept}-lora``.
        concept_name: Trigger word for the concept (spaces become "_").
        description: Base caption; captions are "<description>, <concept>".

    Yields:
        str: progress / error messages, ending with a success summary or a
        failure report containing the tail of the training logs.
    """
    # --- Input validation: fail fast with a user-facing message. ---
    if not dataset_input:
        yield "❌ Por favor, envie um ZIP ou selecione imagens."
        return
    if not concept_name.strip():
        yield "❌ Por favor, defina um nome para o conceito (ex: brenda)."
        return
    if not description.strip():
        yield "❌ Por favor, adicione uma descrição base."
        return

    concept_name = concept_name.strip().replace(" ", "_")
    full_description = f"{description.strip()}, {concept_name}"
    yield f"🏷️ Treinando: '{concept_name}'"

    # --- Stage a clean working directory for the dataset. ---
    dataset_dir = "processed_data"
    os.makedirs(dataset_dir, exist_ok=True)
    for item in os.listdir(dataset_dir):
        item_path = os.path.join(dataset_dir, item)
        try:
            if os.path.isfile(item_path) or os.path.islink(item_path):
                os.unlink(item_path)
            elif os.path.isdir(item_path):
                shutil.rmtree(item_path)
        except OSError as e:
            # Best-effort cleanup: warn the user but keep going.
            yield f"⚠️ Erro ao limpar: {e}"

    # --- Materialize the uploaded files into dataset_dir. ---
    if input_type == "Upload de ZIP":
        zip_file = dataset_input[0] if isinstance(dataset_input, list) else dataset_input
        if not zipfile.is_zipfile(zip_file):
            yield "❌ Arquivo não é um ZIP válido."
            return
        yield "📦 Descompactando..."
        with zipfile.ZipFile(zip_file, "r") as z:
            z.extractall(dataset_dir)
            yield f"✅ ZIP extraído! {len(z.namelist())} arquivos."
    else:
        image_files = dataset_input if isinstance(dataset_input, list) else [dataset_input]
        yield f"🖼️ Copiando {len(image_files)} imagens..."
        for uploaded_file in image_files:
            # Gradio may hand back tempfile-like objects (with .name) or
            # plain path strings depending on its version — accept both,
            # instead of silently skipping string paths.
            src = getattr(uploaded_file, "name", uploaded_file)
            if isinstance(src, str):
                shutil.copy(src, os.path.join(dataset_dir, os.path.basename(src)))
        yield f"✅ {len(image_files)} imagens copiadas."

    # --- Require at least one image, then caption each one. ---
    exts = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')
    images = [f for f in os.listdir(dataset_dir) if f.lower().endswith(exts)]
    if not images:
        yield "❌ Nenhuma imagem encontrada!"
        return

    yield f"📝 Aplicando legenda: '{full_description}'"
    for img in images:
        txt = os.path.join(dataset_dir, os.path.splitext(img)[0] + ".txt")
        # Do not clobber caption files the user shipped inside the ZIP.
        if not os.path.exists(txt):
            with open(txt, "w", encoding="utf-8") as f:
                f.write(full_description)
    yield "🔍 Legendas prontas!"

    # --- Build the training command. ---
    output_dir = "lora-output"
    os.makedirs(output_dir, exist_ok=True)
    cmd = [
        # sys.executable (not a bare "python") guarantees the trainer runs
        # under the same interpreter/venv as this app.
        sys.executable, "train_lora.py",
        "--dataset_dir", dataset_dir,
        "--model_name", model_name,
        "--lora_rank", str(lora_rank),
        "--learning_rate", str(learning_rate),
        "--num_epochs", str(num_epochs),
        "--batch_size", "1",
        "--output_dir", output_dir,
    ]
    if hub_token:
        os.environ["HF_TOKEN"] = hub_token
        cmd += ["--push_to_hub", "--hub_model_id", f"{concept_name}-lora"]

    yield "🔥 Iniciando treinamento..."

    # --- Stream training logs, surfacing loss/epoch lines to the UI. ---
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr into the same stream
            universal_newlines=True,
            bufsize=1,                 # line-buffered so progress arrives live
            encoding='utf-8'
        )

        log_output = ""
        for line in process.stdout:
            log_output += line
            if "loss" in line.lower() or "epoch" in line.lower():
                yield f"📊 {line.strip()}"

        process.wait()

        if process.returncode == 0:
            yield f"""
🎉 SUCESSO!

🔹 Use no prompt: `photo of {concept_name} in the forest`
🔹 Modelo salvo em: `{output_dir}`
{'🔹 Publicado no Hub!' if hub_token else ''}
"""
        else:
            yield f"❌ Falha no treinamento. Código: {process.returncode}\nLogs:\n{log_output[-1000:]}"

    except Exception as e:
        # Surface unexpected launch/stream failures instead of crashing the UI.
        yield f"💥 Erro: {str(e)}"
|
|
|
|
|
|
|
|
# --- Gradio UI: collects dataset + hyperparameters and streams training logs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
    gr.Markdown("Treine personagens, estilos ou objetos personalizados.")

    # Input source: one ZIP archive or a batch of loose image files.
    with gr.Row():
        input_type = gr.Radio(
            ["Upload de ZIP", "Selecionar várias imagens"],
            label="Tipo de Entrada",
            value="Upload de ZIP"
        )

    with gr.Row():
        dataset_input = gr.File(
            label="📤 Envie seu ZIP ou imagens",
            file_types=[".zip", ".jpg", ".jpeg", ".png", ".bmp", ".webp"],
            file_count="multiple"
        )

    # Concept identity: trigger word + base caption applied to every image.
    gr.Markdown("### 🔖 Identidade do Personagem")
    with gr.Row():
        concept_name = gr.Textbox(
            label="Nome do Conceito (ex: brenda)",
            placeholder="Ex: brenda, cyborg_x",
            value=""
        )
    with gr.Row():
        description = gr.Textbox(
            label="Descrição Base (ex: woman, curly hair)",
            placeholder="Ex: young black woman, realistic style",
            lines=2
        )

    # Training hyperparameters, forwarded verbatim to train_lora.py.
    gr.Markdown("### ⚙️ Configurações")
    with gr.Row():
        model_name = gr.Dropdown(
            ["runwayml/stable-diffusion-v1-5"],
            value="runwayml/stable-diffusion-v1-5",
            label="Modelo Base"
        )
        lora_rank = gr.Slider(4, 64, value=4, step=4, label="LoRA Rank")
        learning_rate = gr.Number(value=1e-4, label="Taxa de Aprendizado")
        num_epochs = gr.Slider(1, 30, value=10, step=1, label="Épocas")

    # Optional: pushing to the Hugging Face Hub requires a write token.
    hub_token = gr.Textbox(label="🔐 Token do HF (opcional)", type="password")

    btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
    output = gr.Textbox(label="📦 Logs", lines=12)

    # train_lora_interface is a generator; each yielded message replaces
    # the contents of the `output` textbox as training progresses.
    btn.click(
        train_lora_interface,
        inputs=[
            dataset_input, input_type, model_name, lora_rank,
            learning_rate, num_epochs, hub_token, concept_name, description
        ],
        outputs=output
    )

# Queueing is required for generator callbacks to stream incremental updates.
demo.queue()

if __name__ == "__main__":
    demo.launch()