Allex21 commited on
Commit
4cafbc8
·
verified ·
1 Parent(s): 7865b76

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from preprocess import process_dataset
4
+ import subprocess
5
+ import time
6
+
7
+ def train_lora_interface(dataset_zip, model_name, lora_rank, learning_rate, num_epochs, hub_token):
8
+ # 1. Pré-processamento
9
+ with gr.Progress() as progress:
10
+ progress(0, "Descompactando e processando dataset...")
11
+ dataset_dir = process_dataset(dataset_zip, "processed_data")
12
+
13
+ # 2. Configura treinamento
14
+ progress(0.3, "Configurando treinamento LoRA...")
15
+ output_dir = "lora-output"
16
+ os.makedirs(output_dir, exist_ok=True)
17
+
18
+ # 3. Executa treinamento
19
+ progress(0.5, "Treinando modelo (isso pode levar horas)...")
20
+ cmd = [
21
+ "python", "train_lora.py",
22
+ "--dataset_dir", dataset_dir,
23
+ "--model_name", model_name,
24
+ "--lora_rank", str(lora_rank),
25
+ "--learning_rate", str(learning_rate),
26
+ "--num_epochs", str(num_epochs),
27
+ "--output_dir", output_dir
28
+ ]
29
+
30
+ if hub_token:
31
+ os.environ["HF_TOKEN"] = hub_token
32
+ cmd.append("--push_to_hub")
33
+ cmd.append("--hub_model_id")
34
+ cmd.append("my-lora-model")
35
+
36
+ process = subprocess.Popen(
37
+ cmd,
38
+ stdout=subprocess.PIPE,
39
+ stderr=subprocess.STDOUT,
40
+ universal_newlines=True
41
+ )
42
+
43
+ logs = ""
44
+ for line in process.stdout:
45
+ logs += line
46
+ progress(0.7, f"Treinando...\n{logs[-200:]}")
47
+
48
+ # 4. Finalização
49
+ progress(0.9, "Subindo para Hugging Face Hub...")
50
+ if hub_token:
51
+ from huggingface_hub import upload_folder
52
+ upload_folder(
53
+ repo_id="my-lora-model",
54
+ folder_path=output_dir,
55
+ token=hub_token
56
+ )
57
+
58
+ progress(1.0, "Treinamento concluído com sucesso!")
59
+ return f"Modelo salvo em: {output_dir}\nLogs: {logs[-500:]}"
60
+
61
+ # Interface Gradio
62
+ with gr.Blocks(title="LoRA Trainer - Hugging Face") as demo:
63
+ gr.Markdown("# 🚀 Treinador de LoRA para Stable Diffusion")
64
+ gr.Markdown("Treine seus próprios modelos LoRA diretamente no Hugging Face Spaces!")
65
+
66
+ with gr.Tab("Configuração"):
67
+ dataset_zip = gr.File(label="Dataset (ZIP com imagens)", file_types=['.zip'])
68
+ model_name = gr.Dropdown(
69
+ ["runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-2-1"],
70
+ value="runwayml/stable-diffusion-v1-5",
71
+ label="Modelo Base"
72
+ )
73
+ lora_rank = gr.Slider(4, 64, value=4, step=4, label="Rank LoRA")
74
+ learning_rate = gr.Number(value=1e-4, label="Taxa de Aprendizado")
75
+ num_epochs = gr.Slider(1, 50, value=10, step=1, label="Épocas")
76
+ hub_token = gr.Textbox(label="Token Hugging Face (opcional)", type="password")
77
+
78
+ with gr.Tab("Treinamento"):
79
+ start_btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
80
+ output = gr.Textbox(label="Logs do Treinamento")
81
+
82
+ start_btn.click(
83
+ fn=train_lora_interface,
84
+ inputs=[dataset_zip, model_name, lora_rank, learning_rate, num_epochs, hub_token],
85
+ outputs=output
86
+ )
87
+
88
+ if __name__ == "__main__":
89
+ demo.launch()