|
|
import os
import shutil
import subprocess
import sys

import gradio as gr
|
|
|
|
|
# Working directories for a training run. train_lora deletes and recreates
# both of them at the start of every run, so never point them at data you
# want to keep:
#   UPLOAD_DIR - where the uploaded training images are stored
#   OUTPUT_DIR - passed to train_lora.py as --output_dir
UPLOAD_DIR, OUTPUT_DIR = "training_images", "lora_output"
|
|
|
|
|
def _store_upload(upload, idx):
    """Store one uploaded image in UPLOAD_DIR as ``image_<idx><ext>``."""
    if hasattr(upload, "save"):
        # Already-decoded image object (e.g. a PIL image): keep the original
        # behaviour of saving it as PNG.
        upload.save(os.path.join(UPLOAD_DIR, f"image_{idx}.png"))
        return
    # gr.File hands over either a path string or, depending on the Gradio
    # version, a tempfile-like wrapper whose ``.name`` is the path. Neither
    # has ``.save``, so copy the file and keep its real extension.
    src = upload if isinstance(upload, (str, os.PathLike)) else upload.name
    ext = os.path.splitext(os.fspath(src))[1].lower() or ".png"
    shutil.copyfile(src, os.path.join(UPLOAD_DIR, f"image_{idx}{ext}"))


def train_lora(images, learning_rate, num_epochs, rank):
    """Run one LoRA training job on the uploaded images and report the result.

    Deletes and recreates UPLOAD_DIR and OUTPUT_DIR, stores every upload in
    UPLOAD_DIR, then runs ``train_lora.py`` in a child process and blocks
    until it exits. The script path is relative, so it is resolved against
    the current working directory.

    Args:
        images: Uploads from the ``gr.File`` component (path strings or
            file wrappers); objects exposing ``.save`` are also accepted.
        learning_rate: Forwarded as ``--learning_rate``.
        num_epochs: Forwarded as ``--num_epochs``. ``gr.Number`` may deliver
            a float (e.g. ``10.0``), so it is converted to ``int`` first.
        rank: LoRA rank, converted to ``int`` and forwarded as ``--rank``.

    Returns:
        A user-facing status string. On success it contains the model path
        and the script's stdout. On failure it contains the script's stderr,
        or its stdout when stderr is empty.
    """
    if not images:
        # Check before wiping the directories so a previous model survives.
        return "❌ Erro no treinamento:\nNenhuma imagem enviada."

    # Start every run from empty directories so stale images or an old
    # model cannot leak into (or masquerade as) this run's result.
    for directory in (UPLOAD_DIR, OUTPUT_DIR):
        if os.path.exists(directory):
            shutil.rmtree(directory)
        os.makedirs(directory, exist_ok=True)

    for idx, upload in enumerate(images):
        _store_upload(upload, idx)

    cmd = [
        # Use the interpreter running this app rather than whatever "python"
        # is on PATH, which may be missing or belong to another environment.
        sys.executable or "python", "train_lora.py",
        "--images_dir", UPLOAD_DIR,
        "--output_dir", OUTPUT_DIR,
        "--learning_rate", str(learning_rate),
        "--num_epochs", str(int(num_epochs)),
        "--rank", str(int(rank)),
    ]

    result = subprocess.run(cmd, capture_output=True, text=True)

    # NOTE(review): success is judged only by the presence of this file.
    # That assumes train_lora.py writes lora.safetensors into --output_dir.
    output_file = os.path.join(OUTPUT_DIR, "lora.safetensors")
    if os.path.exists(output_file):
        return f"✅ Treinamento finalizado!\nModelo salvo em: {output_file}\n\nLogs:\n{result.stdout}"
    return f"❌ Erro no treinamento:\n{result.stderr or result.stdout}"
|
|
|
|
|
# Gradio UI: upload images, set hyper-parameters, and run train_lora on click.
with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ Criador & Treinador de LoRA")

    with gr.Row():
        # file_types filters the picker. The "10–50" in the label is only
        # guidance; the count is not validated anywhere in this file.
        image_input = gr.File(
            file_types=[".png", ".jpg", ".jpeg"],
            file_count="multiple",
            label="Envie suas imagens (10–50)",
        )

    with gr.Row():
        learning_rate = gr.Number(value=1e-4, label="Learning Rate")
        # precision=0 makes Gradio deliver whole numbers, so the training
        # script receives "10" rather than "10.0" for integer arguments.
        num_epochs = gr.Number(value=10, precision=0, label="Número de Epochs")
        rank = gr.Number(value=4, precision=0, label="Rank do LoRA")

    with gr.Row():
        train_button = gr.Button("🚀 Treinar LoRA")

    output_text = gr.Textbox(label="Saída", lines=15)

    # Event listeners must be attached inside the Blocks context.
    train_button.click(
        fn=train_lora,
        inputs=[image_input, learning_rate, num_epochs, rank],
        outputs=output_text,
    )

# NOTE(review): a bare expression only has an effect as the last line of a
# notebook cell; when run as a plain script nothing is served. Call
# demo.launch() to start the app from a script.
demo