Allex21 commited on
Commit
a55186d
·
verified ·
1 Parent(s): 80f2676

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +56 -0
main.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import shutil
4
+ import subprocess
5
+
6
+ UPLOAD_DIR = "training_images"
7
+ OUTPUT_DIR = "lora_output"
8
+
9
+ def train_lora(images, learning_rate, num_epochs, rank):
10
+ # Limpa pastas antigas
11
+ if os.path.exists(UPLOAD_DIR):
12
+ shutil.rmtree(UPLOAD_DIR)
13
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
14
+
15
+ if os.path.exists(OUTPUT_DIR):
16
+ shutil.rmtree(OUTPUT_DIR)
17
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
18
+
19
+ # Salva imagens enviadas
20
+ for idx, img in enumerate(images):
21
+ img.save(os.path.join(UPLOAD_DIR, f"image_{idx}.png"))
22
+
23
+ # Executa o script de treinamento
24
+ cmd = [
25
+ "python", "train_lora.py",
26
+ "--images_dir", UPLOAD_DIR,
27
+ "--output_dir", OUTPUT_DIR,
28
+ "--learning_rate", str(learning_rate),
29
+ "--num_epochs", str(num_epochs),
30
+ "--rank", str(rank),
31
+ ]
32
+ result = subprocess.run(cmd, capture_output=True, text=True)
33
+
34
+ # Retorna logs e link para baixar o LoRA
35
+ output_file = os.path.join(OUTPUT_DIR, "lora.safetensors")
36
+ if os.path.exists(output_file):
37
+ return f"✅ Treinamento finalizado!\nModelo salvo em: {output_file}\n\nLogs:\n{result.stdout}"
38
+ else:
39
+ return f"❌ Erro no treinamento:\n{result.stderr}"
40
+
41
+ # Interface Gradio
42
+ with gr.Blocks() as demo:
43
+ gr.Markdown("# 🖼️ Criador & Treinador de LoRA")
44
+ with gr.Row():
45
+ image_input = gr.File(file_types=[".png", ".jpg", ".jpeg"], file_types_display="images", file_count="multiple", label="Envie suas imagens (10–50)")
46
+ with gr.Row():
47
+ learning_rate = gr.Number(value=1e-4, label="Learning Rate")
48
+ num_epochs = gr.Number(value=10, label="Número de Epochs")
49
+ rank = gr.Number(value=4, label="Rank do LoRA")
50
+ with gr.Row():
51
+ train_button = gr.Button("🚀 Treinar LoRA")
52
+ output_text = gr.Textbox(label="Saída", lines=15)
53
+
54
+ train_button.click(fn=train_lora, inputs=[image_input, learning_rate, num_epochs, rank], outputs=output_text)
55
+
56
+ demo.launch()