Allex21 commited on
Commit
f98cb17
·
verified ·
1 Parent(s): 0d6cddb

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +52 -0
main.py CHANGED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import shutil
4
+ import subprocess
5
+
6
+ UPLOAD_DIR = "training_images"
7
+ OUTPUT_DIR = "lora_output"
8
+
9
+ def train_lora(images, learning_rate, num_epochs, rank):
10
+ if os.path.exists(UPLOAD_DIR):
11
+ shutil.rmtree(UPLOAD_DIR)
12
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
13
+
14
+ if os.path.exists(OUTPUT_DIR):
15
+ shutil.rmtree(OUTPUT_DIR)
16
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
17
+
18
+ for idx, img in enumerate(images):
19
+ img.save(os.path.join(UPLOAD_DIR, f"image_{idx}.png"))
20
+
21
+ cmd = [
22
+ "python", "train_lora.py",
23
+ "--images_dir", UPLOAD_DIR,
24
+ "--output_dir", OUTPUT_DIR,
25
+ "--learning_rate", str(learning_rate),
26
+ "--num_epochs", str(num_epochs),
27
+ "--rank", str(rank),
28
+ ]
29
+ result = subprocess.run(cmd, capture_output=True, text=True)
30
+
31
+ output_file = os.path.join(OUTPUT_DIR, "lora.safetensors")
32
+ if os.path.exists(output_file):
33
+ return f"✅ Treinamento finalizado!\nModelo salvo em: {output_file}\n\nLogs:\n{result.stdout}"
34
+ else:
35
+ return f"❌ Erro no treinamento:\n{result.stderr}"
36
+
37
+ # Interface Gradio
38
+ with gr.Blocks() as demo:
39
+ gr.Markdown("# 🖼️ Criador & Treinador de LoRA")
40
+ with gr.Row():
41
+ image_input = gr.File(file_types=[".png", ".jpg", ".jpeg"], file_types_display="images", file_count="multiple", label="Envie suas imagens (10–50)")
42
+ with gr.Row():
43
+ learning_rate = gr.Number(value=1e-4, label="Learning Rate")
44
+ num_epochs = gr.Number(value=10, label="Número de Epochs")
45
+ rank = gr.Number(value=4, label="Rank do LoRA")
46
+ with gr.Row():
47
+ train_button = gr.Button("🚀 Treinar LoRA")
48
+ output_text = gr.Textbox(label="Saída", lines=15)
49
+
50
+ train_button.click(fn=train_lora, inputs=[image_input, learning_rate, num_epochs, rank], outputs=output_text)
51
+
52
+ # 👇 IMPORTANTE: Não usar demo.launch(), apenas expor a variável demo