Allex21 commited on
Commit
0d6cddb
·
verified ·
1 Parent(s): f3b067f

Update main.py

Browse files

import gradio as gr
import os
import shutil
import subprocess

UPLOAD_DIR = "training_images"
OUTPUT_DIR = "lora_output"

def train_lora(images, learning_rate, num_epochs, rank):
if os.path.exists(UPLOAD_DIR):
shutil.rmtree(UPLOAD_DIR)
os.makedirs(UPLOAD_DIR, exist_ok=True)

if os.path.exists(OUTPUT_DIR):
shutil.rmtree(OUTPUT_DIR)
os.makedirs(OUTPUT_DIR, exist_ok=True)

for idx, img in enumerate(images):
img.save(os.path.join(UPLOAD_DIR, f"image_{idx}.png"))

cmd = [
"python", "train_lora.py",
"--images_dir", UPLOAD_DIR,
"--output_dir", OUTPUT_DIR,
"--learning_rate", str(learning_rate),
"--num_epochs", str(num_epochs),
"--rank", str(rank),
]
result = subprocess.run(cmd, capture_output=True, text=True)

output_file = os.path.join(OUTPUT_DIR, "lora.safetensors")
if os.path.exists(output_file):
return f"✅ Treinamento finalizado!\nModelo salvo em: {output_file}\n\nLogs:\n{result.stdout}"
else:
return f"❌ Erro no treinamento:\n{result.stderr}"

# Interface Gradio
with gr.Blocks() as demo:
gr.Markdown("# 🖼️ Criador & Treinador de LoRA")
with gr.Row():
image_input = gr.File(file_types=[".png", ".jpg", ".jpeg"], file_types_display="images", file_count="multiple", label="Envie suas imagens (10–50)")
with gr.Row():
learning_rate = gr.Number(value=1e-4, label="Learning Rate")
num_epochs = gr.Number(value=10, label="Número de Epochs")
rank = gr.Number(value=4, label="Rank do LoRA")
with gr.Row():
train_button = gr.Button("🚀 Treinar LoRA")
output_text = gr.Textbox(label="Saída", lines=15)

train_button.click(fn=train_lora, inputs=[image_input, learning_rate, num_epochs, rank], outputs=output_text)

# 👇 IMPORTANTE: Não usar demo.launch(), apenas expor a variável demo

Files changed (1) hide show
  1. main.py +0 -56
main.py CHANGED
@@ -1,56 +0,0 @@
1
- import gradio as gr
2
- import os
3
- import shutil
4
- import subprocess
5
-
6
- UPLOAD_DIR = "training_images"
7
- OUTPUT_DIR = "lora_output"
8
-
9
- def train_lora(images, learning_rate, num_epochs, rank):
10
- # Limpa pastas antigas
11
- if os.path.exists(UPLOAD_DIR):
12
- shutil.rmtree(UPLOAD_DIR)
13
- os.makedirs(UPLOAD_DIR, exist_ok=True)
14
-
15
- if os.path.exists(OUTPUT_DIR):
16
- shutil.rmtree(OUTPUT_DIR)
17
- os.makedirs(OUTPUT_DIR, exist_ok=True)
18
-
19
- # Salva imagens enviadas
20
- for idx, img in enumerate(images):
21
- img.save(os.path.join(UPLOAD_DIR, f"image_{idx}.png"))
22
-
23
- # Executa o script de treinamento
24
- cmd = [
25
- "python", "train_lora.py",
26
- "--images_dir", UPLOAD_DIR,
27
- "--output_dir", OUTPUT_DIR,
28
- "--learning_rate", str(learning_rate),
29
- "--num_epochs", str(num_epochs),
30
- "--rank", str(rank),
31
- ]
32
- result = subprocess.run(cmd, capture_output=True, text=True)
33
-
34
- # Retorna logs e link para baixar o LoRA
35
- output_file = os.path.join(OUTPUT_DIR, "lora.safetensors")
36
- if os.path.exists(output_file):
37
- return f"✅ Treinamento finalizado!\nModelo salvo em: {output_file}\n\nLogs:\n{result.stdout}"
38
- else:
39
- return f"❌ Erro no treinamento:\n{result.stderr}"
40
-
41
- # Interface Gradio
42
- with gr.Blocks() as demo:
43
- gr.Markdown("# 🖼️ Criador & Treinador de LoRA")
44
- with gr.Row():
45
- image_input = gr.File(file_types=[".png", ".jpg", ".jpeg"], file_types_display="images", file_count="multiple", label="Envie suas imagens (10–50)")
46
- with gr.Row():
47
- learning_rate = gr.Number(value=1e-4, label="Learning Rate")
48
- num_epochs = gr.Number(value=10, label="Número de Epochs")
49
- rank = gr.Number(value=4, label="Rank do LoRA")
50
- with gr.Row():
51
- train_button = gr.Button("🚀 Treinar LoRA")
52
- output_text = gr.Textbox(label="Saída", lines=15)
53
-
54
- train_button.click(fn=train_lora, inputs=[image_input, learning_rate, num_epochs, rank], outputs=output_text)
55
-
56
- demo.launch()