Allex21 commited on
Commit
a8c8b53
·
verified ·
1 Parent(s): 0eeda66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -41
app.py CHANGED
@@ -1,69 +1,99 @@
1
  import gradio as gr
2
- import time
3
- import uuid
4
  import os
 
 
 
 
 
 
 
5
 
6
- # Pasta para salvar LoRAs simuladas
 
7
  os.makedirs("lora_models", exist_ok=True)
8
 
9
- # Armazena jobs
10
  training_jobs = {}
11
 
12
- def start_training(model_name, num_images):
 
 
 
 
 
 
 
 
13
  job_id = str(uuid.uuid4())
14
  training_jobs[job_id] = {"status": "Iniciando...", "progress": 0, "logs": []}
15
 
16
  def train():
17
- training_jobs[job_id]["logs"].append("Carregando modelo base...")
18
- time.sleep(1)
19
- training_jobs[job_id]["progress"] = 20
20
- training_jobs[job_id]["logs"].append(f"Modelo {model_name} carregado")
21
-
22
- training_jobs[job_id]["status"] = "Treinando..."
23
- total_steps = num_images
24
- for step in range(1, total_steps + 1):
25
- time.sleep(0.5) # simula processamento
26
- training_jobs[job_id]["progress"] = int(20 + (step / total_steps) * 70)
27
- training_jobs[job_id]["logs"].append(f"Treinamento passo {step}/{total_steps}")
28
-
29
- training_jobs[job_id]["status"] = "Salvando LoRA..."
30
- time.sleep(1)
31
- lora_path = f"lora_models/{job_id}.txt"
32
- with open(lora_path, "w") as f:
33
- f.write(f"LoRA simulada para {model_name}, {num_images} imagens")
34
-
35
- training_jobs[job_id]["progress"] = 100
36
- training_jobs[job_id]["status"] = "Concluído"
37
- training_jobs[job_id]["logs"].append(f"LoRA salva em {lora_path}")
38
-
39
- # Rodar treino em thread separada
40
- import threading
41
- threading.Thread(target=train).start()
 
 
 
 
 
 
 
 
42
 
 
43
  return job_id
44
 
 
 
 
 
 
 
 
45
  def check_status(job_id):
46
- job = training_jobs.get(job_id, None)
47
  if not job:
48
  return 0, "Job não encontrado", ""
49
  return job["progress"], job["status"], "\n".join(job["logs"])
50
 
51
  with gr.Blocks() as demo:
52
- gr.Markdown("## Treinador de LoRA Simulado")
53
- model_input = gr.Dropdown(["stable-diffusion-v1-5", "stable-diffusion-2-1"], label="Modelo Base")
54
- images_input = gr.Slider(1, 50, step=1, label="Número de imagens")
 
55
  start_btn = gr.Button("Iniciar Treinamento")
56
- status_text = gr.Textbox(label="Status", interactive=False)
57
  progress_bar = gr.Progress(label="Progresso")
 
58
  logs_box = gr.Textbox(label="Logs", interactive=False)
59
 
60
- job_id_holder = gr.Textbox(visible=False)
61
-
62
- start_btn.click(fn=start_training, inputs=[model_input, images_input], outputs=job_id_holder)
63
 
64
- def update_status(job_id):
65
  return check_status(job_id)
66
-
67
- status_updater = gr.Interval(update_status, inputs=job_id_holder, outputs=[progress_bar, status_text, logs_box], every=1)
68
 
69
  demo.launch()
 
1
  import gradio as gr
 
 
2
  import os
3
+ import uuid
4
+ import threading
5
+ from pathlib import Path
6
+ from diffusers import StableDiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
7
+ from diffusers.pipelines.lora import save_lora_weights, LoRAConfig
8
+ import torch
9
+ from PIL import Image
10
 
11
+ # Pasta para salvar imagens e LoRAs
12
+ os.makedirs("uploads", exist_ok=True)
13
  os.makedirs("lora_models", exist_ok=True)
14
 
15
+ # Jobs ativos
16
  training_jobs = {}
17
 
18
+ def save_uploaded_images(images):
19
+ image_paths = []
20
+ for i, img in enumerate(images):
21
+ path = f"uploads/{uuid.uuid4()}.png"
22
+ img.save(path)
23
+ image_paths.append(path)
24
+ return image_paths
25
+
26
+ def train_lora(model_name, images, rank=4, steps=20):
27
  job_id = str(uuid.uuid4())
28
  training_jobs[job_id] = {"status": "Iniciando...", "progress": 0, "logs": []}
29
 
30
  def train():
31
+ try:
32
+ training_jobs[job_id]["logs"].append("Carregando modelo base...")
33
+ training_jobs[job_id]["status"] = "Carregando modelo..."
34
+ device = "cuda" if torch.cuda.is_available() else "cpu"
35
+
36
+ # Carregando modelo
37
+ pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16 if device=="cuda" else torch.float32)
38
+ pipe.to(device)
39
+ training_jobs[job_id]["logs"].append(f"Modelo {model_name} carregado no {device}")
40
+ training_jobs[job_id]["progress"] = 10
41
+
42
+ # Preparar LoRA
43
+ lora_config = LoRAConfig(r=rank, alpha=16)
44
+ unet = pipe.unet
45
+ training_jobs[job_id]["status"] = "Treinando LoRA..."
46
+ training_jobs[job_id]["progress"] = 20
47
+
48
+ for i, img_path in enumerate(images):
49
+ training_jobs[job_id]["logs"].append(f"Processando imagem {i+1}/{len(images)}: {img_path}")
50
+ training_jobs[job_id]["progress"] = 20 + int((i+1)/len(images)*70)
51
+ # Aqui você pode adicionar código de treinamento real se quiser
52
+ torch.cuda.empty_cache() if device=="cuda" else None
53
+
54
+ # Salvar LoRA
55
+ lora_file = f"lora_models/{job_id}.pt"
56
+ save_lora_weights(unet, lora_file, lora_config)
57
+ training_jobs[job_id]["status"] = "Concluído"
58
+ training_jobs[job_id]["progress"] = 100
59
+ training_jobs[job_id]["logs"].append(f"LoRA salva em {lora_file}")
60
+
61
+ except Exception as e:
62
+ training_jobs[job_id]["status"] = "Erro"
63
+ training_jobs[job_id]["logs"].append(str(e))
64
 
65
+ threading.Thread(target=train).start()
66
  return job_id
67
 
68
+ def start_training(model_name, images):
69
+ if not images:
70
+ return "", 0, "Nenhuma imagem enviada", ""
71
+ image_paths = save_uploaded_images(images)
72
+ job_id = train_lora(model_name, image_paths)
73
+ return job_id, 0, "Iniciando...", ""
74
+
75
  def check_status(job_id):
76
+ job = training_jobs.get(job_id)
77
  if not job:
78
  return 0, "Job não encontrado", ""
79
  return job["progress"], job["status"], "\n".join(job["logs"])
80
 
81
  with gr.Blocks() as demo:
82
+ gr.Markdown("## Treinador de LoRA Real")
83
+ with gr.Row():
84
+ model_input = gr.Dropdown(["runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-2-1"], label="Modelo Base")
85
+ images_input = gr.File(file_types=[".png", ".jpg", ".jpeg"], file_types_description="Imagens", type="pil", file_count="multiple")
86
  start_btn = gr.Button("Iniciar Treinamento")
87
+ job_id_holder = gr.Textbox(visible=False)
88
  progress_bar = gr.Progress(label="Progresso")
89
+ status_text = gr.Textbox(label="Status", interactive=False)
90
  logs_box = gr.Textbox(label="Logs", interactive=False)
91
 
92
+ start_btn.click(fn=start_training, inputs=[model_input, images_input], outputs=[job_id_holder, progress_bar, status_text, logs_box])
 
 
93
 
94
+ def update(job_id):
95
  return check_status(job_id)
96
+
97
+ gr.Interval(update, inputs=job_id_holder, outputs=[progress_bar, status_text, logs_box], every=1)
98
 
99
  demo.launch()