Allex21 commited on
Commit
cb03193
·
verified ·
1 Parent(s): f98cb17

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import shutil
4
+ import subprocess
5
+
6
+ UPLOAD_DIR = "training_images"
7
+ OUTPUT_DIR = "lora_output"
8
+
9
+ def train_lora(images, learning_rate, num_epochs, rank):
10
+ if os.path.exists(UPLOAD_DIR):
11
+ shutil.rmtree(UPLOAD_DIR)
12
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
13
+
14
+ if os.path.exists(OUTPUT_DIR):
15
+ shutil.rmtree(OUTPUT_DIR)
16
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
17
+
18
+ for idx, img in enumerate(images):
19
+ img.save(os.path.join(UPLOAD_DIR, f"image_{idx}.png"))
20
+
21
+ cmd = [
22
+ "python", "train_lora.py",
23
+ "--images_dir", UPLOAD_DIR,
24
+ "--output_dir", OUTPUT_DIR,
25
+ "--learning_rate", str(learning_rate),
26
+ "--num_epochs", str(num_epochs),
27
+ "--rank", str(rank),
28
+ ]
29
+ result = subprocess.run(cmd, capture_output=True, text=True)
30
+
31
+ output_file = os.path.join(OUTPUT_DIR, "lora.safetensors")
32
+ if os.path.exists(output_file):
33
+ return f"✅ Treinamento finalizado!\nModelo salvo em: {output_file}\n\nLogs:\n{result.stdout}"
34
+ else:
35
+ return f"❌ Erro no treinamento:\n{result.stderr}"
36
+
37
+ with gr.Blocks() as demo:
38
+ gr.Markdown("# 🖼️ Criador & Treinador de LoRA")
39
+ with gr.Row():
40
+ image_input = gr.File(
41
+ file_types=[".png", ".jpg", ".jpeg"],
42
+ file_types_display="images",
43
+ file_count="multiple",
44
+ label="Envie suas imagens (10–50)"
45
+ )
46
+ with gr.Row():
47
+ learning_rate = gr.Number(value=1e-4, label="Learning Rate")
48
+ num_epochs = gr.Number(value=10, label="Número de Epochs")
49
+ rank = gr.Number(value=4, label="Rank do LoRA")
50
+ with gr.Row():
51
+ train_button = gr.Button("🚀 Treinar LoRA")
52
+ output_text = gr.Textbox(label="Saída", lines=15)
53
+
54
+ train_button.click(
55
+ fn=train_lora,
56
+ inputs=[image_input, learning_rate, num_epochs, rank],
57
+ outputs=output_text
58
+ )
59
+
60
+ # 👇 MUITO IMPORTANTE: apenas expor a variável demo
61
+ demo