|
|
|
|
|
|
|
|
import subprocess |
|
|
import sys |
|
|
import os |
|
|
import shutil |
|
|
|
|
|
def install_packages():
    """Force-reinstall the training dependencies into the running interpreter.

    pip is invoked through ``sys.executable`` so the packages land in the same
    environment that is executing this script.

    Returns:
        True if pip exited successfully, False otherwise. Failures are only
        printed (not raised) so the app can still start and report problems.
    """
    print("🔧 Forçando reinstalação de diffusers e huggingface_hub...")
    try:
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "--force-reinstall",
            "diffusers>=0.26.0",
            "huggingface-hub>=0.20.0",
            "accelerate",
            "peft",
            "torch==2.3.0",
            "transformers==4.40.0",
        ])
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: pip ran but exited non-zero.
        # OSError: the interpreter could not be launched at all.
        print(f"❌ Falha na instalação forçada: {e}")
        return False
    print("✅ Pacotes essenciais instalados!")
    return True


# Marker file whose presence skips the (slow) reinstall on later starts.
_INSTALL_MARKER = "/tmp/packages_installed"

if not os.path.exists(_INSTALL_MARKER):
    # Only record a successful install; writing the marker after a failure
    # would prevent the install from ever being retried.
    if install_packages():
        with open(_INSTALL_MARKER, "w") as f:
            f.write("ok")
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
from preprocess import process_dataset |
|
|
import subprocess |
|
|
import zipfile |
|
|
import time |
|
|
|
|
|
def _upload_path(uploaded):
    """Return the local filesystem path of one Gradio upload value, or None.

    Depending on the Gradio version / ``gr.File`` ``type`` setting, an upload
    arrives either as a plain path string or as a file object exposing its
    path through ``.name``; both forms are accepted. Anything else gives None.
    """
    if isinstance(uploaded, (str, os.PathLike)):
        return os.fspath(uploaded)
    path = getattr(uploaded, "name", None)
    return path if isinstance(path, str) else None


def _clear_directory(directory):
    """Delete every entry inside *directory*, yielding a warning per failure."""
    for entry in os.listdir(directory):
        entry_path = os.path.join(directory, entry)
        try:
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except OSError as e:
            yield f"⚠️ Erro ao limpar: {e}"


def _write_default_captions(dataset_dir, image_names, caption):
    """Create ``<stem>.txt`` holding *caption* for each image lacking one."""
    for img_name in image_names:
        txt_path = os.path.join(dataset_dir, os.path.splitext(img_name)[0] + ".txt")
        if os.path.exists(txt_path):
            # Keep captions that came with the uploaded data (e.g. in a ZIP).
            continue
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(caption)


def train_lora_interface(
    dataset_input, input_type, model_name, lora_rank, learning_rate,
    num_epochs, hub_token, concept_name, description
):
    """Prepare a captioned dataset and run ``train_lora.py``, streaming status.

    This is a generator: every yielded string is a status message for the UI.

    Args:
        dataset_input: Upload value(s) from ``gr.File`` (a list or one item).
        input_type: ``"Upload de ZIP"`` to extract the first upload as a ZIP;
            any other value copies each upload as an individual image.
        model_name: Base model identifier forwarded as ``--model_name``.
        lora_rank: Forwarded as ``--lora_rank`` (converted to int).
        learning_rate: Forwarded as ``--learning_rate``.
        num_epochs: Forwarded as ``--num_epochs`` (converted to int).
        hub_token: Optional Hugging Face token. When non-blank it is given to
            the training subprocess as ``HF_TOKEN`` and enables Hub upload.
        concept_name: Trigger word; spaces are replaced by underscores.
        description: Base caption; the concept name is appended to it.

    Side effects:
        Empties and refills ``./processed_data``, creates ``./lora-output``
        and launches ``train_lora.py`` with the current interpreter.
    """
    if not dataset_input:
        yield "❌ Por favor, envie um ZIP ou selecione imagens."
        return
    if not concept_name.strip():
        yield "❌ Por favor, defina um nome para o conceito (ex: brenda)."
        return
    if not description.strip():
        yield "❌ Por favor, adicione uma descrição base (ex: mulher, 30 anos, cabelo cacheado)."
        return

    concept_name = concept_name.strip().replace(" ", "_")
    full_description = f"{description.strip()}, {concept_name}"
    yield f"🏷️ Treinando conceito: '{concept_name}' → Prompt: [photo of {concept_name}]"

    # --- Dataset preparation -------------------------------------------------
    dataset_dir = "processed_data"
    os.makedirs(dataset_dir, exist_ok=True)
    # Start from an empty folder so files from a previous run are not reused.
    yield from _clear_directory(dataset_dir)

    if input_type == "Upload de ZIP":
        first_upload = dataset_input[0] if isinstance(dataset_input, list) else dataset_input
        zip_path = _upload_path(first_upload)
        if zip_path is None or not zipfile.is_zipfile(zip_path):
            yield "❌ Arquivo não é um ZIP válido."
            return

        yield "📦 Descompactando ZIP..."
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(dataset_dir)
            member_count = len(z.namelist())
        # NOTE(review): only top-level files are scanned below, so images
        # stored inside folders of the ZIP are not picked up.
        yield f"✅ ZIP extraído! {member_count} arquivos."
    else:
        uploads = dataset_input if isinstance(dataset_input, list) else [dataset_input]
        yield f"🖼️ Recebidas {len(uploads)} imagens. Copiando..."
        copied = 0
        for uploaded_file in uploads:
            src_path = _upload_path(uploaded_file)
            if src_path is None:
                yield f"⚠️ Arquivo inválido: {uploaded_file}"
                continue
            shutil.copy(src_path, os.path.join(dataset_dir, os.path.basename(src_path)))
            copied += 1
        # Report what was actually copied, not what was received.
        yield f"✅ {copied} imagens copiadas."

    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')
    image_files = [f for f in os.listdir(dataset_dir) if f.lower().endswith(image_extensions)]
    if not image_files:
        yield "❌ Nenhuma imagem encontrada. Envie arquivos válidos."
        return

    yield f"📝 Aplicando legenda base: '{full_description}'"
    _write_default_captions(dataset_dir, image_files, full_description)
    yield "🔍 Legendas aplicadas com sucesso!"

    # --- Training command ----------------------------------------------------
    output_dir = "lora-output"
    os.makedirs(output_dir, exist_ok=True)

    cmd = [
        # Use this very interpreter: the packages were installed into it, and a
        # bare "python" on PATH may resolve to a different installation.
        sys.executable, "train_lora.py",
        "--dataset_dir", dataset_dir,
        "--model_name", model_name,
        "--lora_rank", str(int(lora_rank)),
        "--learning_rate", str(learning_rate),
        "--num_epochs", str(int(num_epochs)),
        "--batch_size", "1",
        "--output_dir", output_dir,
    ]

    # Hand the token to the child process only. Writing it into os.environ
    # would leak one user's credential into every later request served by
    # this same server process.
    hub_token = (hub_token or "").strip()
    child_env = dict(os.environ)
    if hub_token:
        child_env["HF_TOKEN"] = hub_token
        cmd += ["--push_to_hub", "--hub_model_id", f"{concept_name}-lora"]

    yield "🔥 Iniciando treinamento LoRA... Isso pode levar alguns minutos."

    # --- Run and stream ------------------------------------------------------
    process = None
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",  # undecodable trainer output must not abort streaming
            bufsize=1,
            env=child_env,
        )

        log_lines = []
        for line in process.stdout:
            log_lines.append(line)
            lowered = line.lower()
            if "loss" in lowered or "epoch" in lowered:
                yield f"📊 {line.strip()}"

        process.wait()

        if process.returncode == 0:
            hub_note = '🔹 Publicado no Hub!' if hub_token else ''
            yield (
                "\n🎉 TREINAMENTO CONCLUÍDO!\n\n"
                f"🔹 Use no prompt: `photo of {concept_name} in the forest`\n"
                f"🔹 Modelo salvo em: `{output_dir}`\n"
                f"{hub_note}\n"
            )
        else:
            log_output = "".join(log_lines)
            yield f"❌ Treinamento falhou. Código: {process.returncode}\nLogs:\n{log_output[-1000:]}"

    except Exception as e:  # UI boundary: report anything instead of crashing the handler
        yield f"💥 Erro ao executar treinamento: {str(e)}"
    finally:
        # Runs also when the generator is closed early (e.g. client disconnect):
        # do not leave an orphaned training process or an open pipe behind.
        if process is not None:
            if process.poll() is None:
                process.kill()
                process.wait()
            if process.stdout is not None:
                process.stdout.close()
|
|
|
|
|
|
|
|
# Gradio front-end: dataset upload, concept identity and training settings,
# all wired to ``train_lora_interface``.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
    gr.Markdown("Treine seu próprio modelo com nome, descrição e imagens!")

    # How the dataset is delivered: a single ZIP archive or loose image files.
    with gr.Row():
        input_type = gr.Radio(
            choices=["Upload de ZIP", "Selecionar várias imagens"],
            value="Upload de ZIP",
            label="Tipo de Entrada",
        )

    with gr.Row():
        dataset_input = gr.File(
            file_count="multiple",
            file_types=[".zip", ".jpg", ".jpeg", ".png", ".bmp", ".webp"],
            label="📤 Envie seu ZIP ou imagens",
        )

    # Trigger word and base caption written next to every training image.
    gr.Markdown("### 🔖 Identidade do Personagem/Conceito")
    with gr.Row():
        concept_name = gr.Textbox(
            value="",
            placeholder="Ex: brenda, cyborg_x, estilo_pintura",
            label="Nome do Conceito (ex: brenda)",
        )
    with gr.Row():
        description = gr.Textbox(
            lines=2,
            placeholder="Ex: young black woman, warm smile, detailed face",
            label="Descrição Base (ex: woman, curly hair, realistic)",
        )

    # Hyper-parameters forwarded to train_lora.py.
    gr.Markdown("### ⚙️ Configurações do Treinamento")
    with gr.Row():
        model_name = gr.Dropdown(
            choices=["runwayml/stable-diffusion-v1-5"],
            label="Modelo Base",
            value="runwayml/stable-diffusion-v1-5",
        )
        lora_rank = gr.Slider(minimum=4, maximum=64, step=4, value=4, label="LoRA Rank")
        learning_rate = gr.Number(label="Taxa de Aprendizado", value=1e-4)
        num_epochs = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Épocas")

    hub_token = gr.Textbox(type="password", label="🔐 Token do Hugging Face (opcional)")

    btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
    output = gr.Textbox(lines=12, label="📦 Logs e Resultado")

    # The order of this list must match train_lora_interface's parameters.
    btn.click(
        fn=train_lora_interface,
        inputs=[
            dataset_input,
            input_type,
            model_name,
            lora_rank,
            learning_rate,
            num_epochs,
            hub_token,
            concept_name,
            description,
        ],
        outputs=output,
    )


# Enable the request queue; the click handler is a generator that streams
# progress messages into the output textbox.
demo.queue()

if __name__ == "__main__":
    demo.launch()