File size: 39,393 Bytes
0eeda66
8f49e8c
a8c8b53
8f49e8c
a8c8b53
8f49e8c
 
a8c8b53
8f49e8c
 
 
 
 
a8c8b53
 
8f49e8c
 
 
 
 
 
 
 
8da6913
8f49e8c
404dd88
8f49e8c
 
 
 
 
 
 
119d0f5
 
 
 
 
3e57297
119d0f5
3e57297
119d0f5
 
 
 
3e57297
119d0f5
 
 
3e57297
119d0f5
 
 
 
 
 
3e57297
119d0f5
 
 
 
3e57297
119d0f5
 
3e57297
119d0f5
 
 
 
 
 
 
3e57297
119d0f5
3e57297
119d0f5
 
3e57297
 
119d0f5
3e57297
119d0f5
3e57297
119d0f5
 
 
 
 
 
 
3e57297
119d0f5
 
 
 
3e57297
119d0f5
3e57297
 
119d0f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e57297
119d0f5
 
3e57297
119d0f5
 
 
 
 
 
 
3e57297
119d0f5
 
3e57297
119d0f5
 
 
 
 
3e57297
119d0f5
 
8f49e8c
119d0f5
 
8f49e8c
 
 
3e57297
 
 
 
8f49e8c
 
 
3e57297
 
8f49e8c
3e57297
8f49e8c
 
3e57297
 
 
119d0f5
 
 
 
 
 
 
3e57297
119d0f5
 
 
 
394ecc3
 
 
 
3e57297
8da6913
 
 
 
 
 
 
 
3e57297
8da6913
a82fdb3
3e57297
 
8da6913
 
119d0f5
3e57297
 
119d0f5
3e57297
8f49e8c
119d0f5
 
3e57297
119d0f5
 
 
 
 
 
3e57297
8f49e8c
 
119d0f5
 
3e57297
119d0f5
8f49e8c
119d0f5
8f49e8c
 
3e57297
119d0f5
 
3e57297
 
119d0f5
 
 
 
3e57297
119d0f5
 
 
3e57297
119d0f5
 
 
3e57297
119d0f5
3e57297
 
119d0f5
 
 
3e57297
119d0f5
8da6913
119d0f5
3e57297
119d0f5
3e57297
 
119d0f5
 
 
 
3e57297
8f49e8c
 
119d0f5
 
 
3e57297
119d0f5
3e57297
8f49e8c
 
119d0f5
8f49e8c
 
119d0f5
394ecc3
404dd88
 
 
119d0f5
3e57297
8f49e8c
 
 
8da6913
8f49e8c
 
394ecc3
8f49e8c
 
 
 
 
 
 
 
 
 
 
119d0f5
3e57297
8f49e8c
3e57297
 
 
119d0f5
3e57297
 
 
 
 
 
 
 
 
 
 
 
 
 
8f49e8c
 
 
119d0f5
3e57297
8f49e8c
 
 
 
3e57297
 
 
119d0f5
8f49e8c
3e57297
119d0f5
3e57297
 
 
119d0f5
 
8f49e8c
 
 
 
3e57297
 
119d0f5
5e10fc1
3e57297
119d0f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e57297
119d0f5
 
 
 
 
 
 
 
8f49e8c
119d0f5
10a7187
3e57297
10a7187
 
8f49e8c
3e57297
8f49e8c
 
119d0f5
 
 
 
 
 
 
 
 
3e57297
119d0f5
 
 
 
 
 
 
3e57297
119d0f5
 
 
 
 
 
 
3e57297
8f49e8c
119d0f5
8f49e8c
3e57297
8f49e8c
3e57297
119d0f5
 
 
 
 
 
3e57297
8f49e8c
a8c8b53
119d0f5
3e57297
8f49e8c
0eeda66
8f49e8c
3e57297
 
 
119d0f5
3e57297
119d0f5
3e57297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119d0f5
3e57297
 
119d0f5
 
3e57297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10a7187
3e57297
 
 
 
 
 
 
119d0f5
 
3e57297
 
 
 
 
 
 
 
 
 
 
 
 
119d0f5
8f49e8c
3e57297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119d0f5
 
3e57297
119d0f5
3e57297
 
 
 
 
 
 
119d0f5
3e57297
 
 
 
 
 
119d0f5
 
 
3e57297
 
f8c89ea
3e57297
 
 
 
f8c89ea
3e57297
 
119d0f5
3e57297
 
 
119d0f5
3e57297
8f49e8c
119d0f5
3e57297
 
 
 
119d0f5
 
3e57297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f49e8c
 
 
119d0f5
3e57297
119d0f5
3e57297
 
 
 
 
 
 
 
 
 
 
 
 
 
8f49e8c
119d0f5
3e57297
 
 
 
119d0f5
3e57297
 
 
 
 
 
 
119d0f5
3e57297
 
 
 
 
119d0f5
3e57297
8f49e8c
3e57297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119d0f5
 
3e57297
 
 
 
 
119d0f5
3e57297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119d0f5
3e57297
 
 
 
 
 
 
 
 
8f49e8c
3e57297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f49e8c
 
 
3e57297
8f49e8c
3e57297
 
 
 
119d0f5
3e57297
 
5e10fc1
 
 
3e57297
5e10fc1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
import os
import json
import uuid
import shutil
import threading
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple
import zipfile
import tempfile

import gradio as gr
import torch
from PIL import Image
import numpy as np
from diffusers import (
    StableDiffusionPipeline,
    UNet2DConditionModel,
    DDPMScheduler,
    AutoencoderKL
)
from transformers import CLIPTextModel, CLIPTokenizer
from peft import LoraConfig
import logging
from safetensors.torch import save_file

# Configurar logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class LoRAImageTrainer:
    """Classe principal para treinamento de modelos LoRA para geração de imagens otimizada para baixo uso de GPU."""

    def __init__(self):  
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
        self.training_jobs = {}  
        self.models_cache = {}  
          
    def get_available_models(self) -> List[str]:  
        """Retorna lista de modelos base disponíveis para treinamento LoRA."""  
        return [  
            "runwayml/stable-diffusion-v1-5",  
            "stabilityai/stable-diffusion-2-1",  
            "CompVis/stable-diffusion-v1-4"  
            # XL removido por ser pesado demais para Spaces gratuitos
        ]  
  
    def load_base_model(self, model_name: str):  
        """Carrega modelo base de difusão com otimizações para baixo uso de GPU."""  
        try:  
            if model_name in self.models_cache:  
                return self.models_cache[model_name]  
              
            logger.info(f"Carregando modelo base: {model_name}")  
              
            # Configurações para otimização de memória  
            model_kwargs = {  
                "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,  
                "use_safetensors": True,  
                "variant": "fp16" if torch.cuda.is_available() else None,  
                "safety_checker": None,  # Desativa verificador de segurança para economizar memória
            }  
              
            # Carregar pipeline completo  
            pipeline = StableDiffusionPipeline.from_pretrained(  
                model_name,  
                **model_kwargs  
            )  
              
            if torch.cuda.is_available():  
                pipeline = pipeline.to(self.device)  
                # Habilitar attention slicing para economia de memória  
                pipeline.enable_attention_slicing()  
                # Habilitar memory efficient attention se disponível  
                try:  
                    pipeline.enable_xformers_memory_efficient_attention()  
                except Exception as e:  
                    logger.warning("xformers não disponível, usando attention padrão")  
              
            # Cache do modelo  
            self.models_cache[model_name] = pipeline  
              
            return pipeline  
              
        except Exception as e:  
            logger.error(f"Erro ao carregar modelo {model_name}: {str(e)}")  
            raise e  
  
    def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:  
        """Prepara dataset de imagens para treinamento."""  
        dataset = []  
          
        for img_path, caption in zip(image_files, captions):  
            try:  
                # Carregar e redimensionar imagem  
                image = Image.open(img_path).convert("RGB")  
                  
                # Redimensionar mantendo aspect ratio  
                image = self.resize_image(image, resolution)  
                  
                dataset.append({  
                    "image": image,  
                    "caption": caption,  
                    "image_path": img_path  
                })  
                  
            except Exception as e:  
                logger.error(f"Erro ao processar imagem {img_path}: {str(e)}")  
                continue  
          
        return dataset  
  
    def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:  
        """Redimensiona imagem mantendo aspect ratio e fazendo crop central se necessário."""  
        width, height = image.size  
          
        # Calcular novo tamanho mantendo aspect ratio  
        if width > height:  
            new_width = target_size  
            new_height = int((height * target_size) / width)  
        else:  
            new_height = target_size  
            new_width = int((width * target_size) / height)  
          
        # Redimensionar  
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)  
          
        # Crop central para obter tamanho exato  
        if new_width != target_size or new_height != target_size:  
            left = (new_width - target_size) // 2  
            top = (new_height - target_size) // 2  
            right = left + target_size  
            bottom = top + target_size  
              
            image = image.crop((left, top, right, bottom))  
          
        return image
  
    def real_lora_training(self,   
                         job_id: str,
                         model_name: str,
                         dataset: List[Dict],
                         r: int = 16,
                         lora_alpha: int = 32,
                         lora_dropout: float = 0.1,
                         num_epochs: int = 10,
                         learning_rate: float = 1e-4,
                         batch_size: int = 1,
                         resolution: int = 512) -> None:
        """TREINAMENTO REAL DE LoRA PARA IMAGENS - CORRIGIDO PARA DIFFUSERS + PEFT."""
        
        try:
            # Atualizar status
            self.training_jobs[job_id]["status"] = "loading_model"
            self.training_jobs[job_id]["progress"] = 5
            self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}")
            
            # Carregar modelo base
            pipeline = self.load_base_model(model_name)
            unet = pipeline.unet
            text_encoder = pipeline.text_encoder
            vae = pipeline.vae
            tokenizer = pipeline.tokenizer
            scheduler = pipeline.scheduler

            # Congelar parâmetros
            unet.requires_grad_(False)
            text_encoder.requires_grad_(False)
            vae.requires_grad_(False)

            # ✅ ✅ ✅ CORREÇÃO: REMOVER ADAPTADOR EXISTENTE ✅ ✅ ✅
            if hasattr(unet, "peft_config") and "default" in unet.peft_config:
                unet.delete_adapter("default")

            # Criar configuração LoRA
            lora_config = LoraConfig(
                r=r,
                lora_alpha=lora_alpha,
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
                lora_dropout=lora_dropout,
                bias="none"
            )

            # Aplicar LoRA ao UNet manualmente, sem usar get_peft_model diretamente
            unet.add_adapter(lora_config, adapter_name="default")

            # Ativar o adaptador
            unet.set_adapter("default")
            unet.train()
            unet.to(self.device)

            # Otimizador
            optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)

            # Preparar scheduler para treinamento
            self.training_jobs[job_id]["status"] = "preparing_data"
            self.training_jobs[job_id]["progress"] = 20

            # Normalização de imagem
            def preprocess_image(image):
                image = np.array(image).astype(np.float32) / 255.0
                image = image.transpose(2, 0, 1)
                image = torch.from_numpy(image).unsqueeze(0)
                return image

            # Loop de treinamento real
            total_steps = num_epochs * len(dataset)
            current_step = 0

            self.training_jobs[job_id]["status"] = "training"
            self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real...")

            for epoch in range(num_epochs):
                for item in dataset:
                    current_step += 1
                    
                    # Obter imagem e legenda
                    image = item["image"]
                    caption = item["caption"]

                    # Pré-processar imagem
                    image_tensor = preprocess_image(image).to(self.device)
                    if torch.cuda.is_available():
                        image_tensor = image_tensor.half()

                    # Codificar imagem para latentes
                    with torch.no_grad():
                        latents = vae.encode(image_tensor * 2 - 1).latent_dist.sample() * 0.18215

                    # Tokenizar texto
                    inputs = tokenizer(caption, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
                    input_ids = inputs.input_ids.to(self.device)

                    # Gerar timesteps aleatórios
                    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=self.device).long()

                    # Adicionar ruído aos latentes
                    noise = torch.randn_like(latents)
                    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

                    # Forward pass
                    encoder_hidden_states = text_encoder(input_ids)[0]
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample

                    # Calcular perda
                    loss = torch.nn.functional.mse_loss(noise_pred, noise)

                    # Backward pass
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # Atualizar progresso
                    progress = 30 + int((current_step / total_steps) * 60)
                    self.training_jobs[job_id]["progress"] = min(progress, 90)

                    if current_step % max(1, len(dataset)//2) == 0:
                        log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
                        self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_msg}")

            # Salvar LoRA
            self.training_jobs[job_id]["status"] = "saving"
            self.training_jobs[job_id]["progress"] = 95

            output_dir = f"./lora_models/{job_id}"
            os.makedirs(output_dir, exist_ok=True)

            # ✅ SALVAR APENAS OS PESOS DO LORA (NÃO O MODELO BASE)
            adapter_state_dict = unet.state_dict()
            adapter_weights = {k: v for k, v in adapter_state_dict.items() if "lora_" in k}
            save_file(adapter_weights, f"{output_dir}/adapter_model.safetensors")

            # Criar adapter_config.json
            lora_config_dict = {
                "r": r,
                "lora_alpha": lora_alpha,
                "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
                "lora_dropout": lora_dropout,
                "bias": "none",
                "task_type": "CAUSAL_LM",
                "base_model_name": model_name,
                "training_info": {
                    "num_epochs": num_epochs,
                    "learning_rate": learning_rate,
                    "batch_size": batch_size,
                    "resolution": resolution,
                    "num_images": len(dataset)
                }
            }
            with open(f"{output_dir}/adapter_config.json", "w") as f:
                json.dump(lora_config_dict, f, indent=2)

            # README
            readme_content = f"""# LoRA Model - {job_id}

Informações do Treinamento

Modelo Base: {model_name}
Rank (r): {r}
LoRA Alpha: {lora_alpha}
Dropout: {lora_dropout}
Épocas: {num_epochs}
Taxa de Aprendizado: {learning_rate}
Resolução: {resolution}x{resolution}
Número de Imagens: {len(dataset)}
Data de Treinamento: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

Como Usar

1. Baixe os arquivos adapter_config.json e adapter_model.safetensors
2. Carregue em sua ferramenta de geração de imagens favorita (ComfyUI, Automatic1111, etc.)
3. Use o trigger word ou estilo aprendido durante o treinamento
"""
            with open(f"{output_dir}/README.md", "w") as f:
                f.write(readme_content)

            # Finalizar
            self.training_jobs[job_id]["status"] = "completed"
            self.training_jobs[job_id]["progress"] = 100
            self.training_jobs[job_id]["model_path"] = output_dir
            self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
            self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ✅ Treinamento REAL concluído! LoRA salvo em {output_dir}")

            logger.info(f"Treinamento LoRA REAL concluído para job {job_id}")

        except Exception as e:
            error_msg = f"Erro REAL no treinamento: {str(e)}"
            logger.error(error_msg)
            self.training_jobs[job_id]["status"] = "error"
            self.training_jobs[job_id]["error"] = error_msg
            self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ❌ {error_msg}")

    def start_training(self,
                      model_name: str,
                      image_files: List[str],
                      captions: List[str],
                      **kwargs) -> str:
        """Inicia treinamento LoRA assíncrono."""

        job_id = str(uuid.uuid4())  

        # Preparar dataset  
        dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))  

        self.training_jobs[job_id] = {  
            "id": job_id,  
            "status": "queued",  
            "progress": 0,  
            "created_at": datetime.now().isoformat(),  
            "model_name": model_name,  
            "num_images": len(dataset),  
            "logs": [],  
            "error": None,  
            "model_path": None,  
            "completed_at": None  
        }  

        # Iniciar treinamento em thread separada  
        thread = threading.Thread(  
            target=self.real_lora_training,  
            args=(job_id, model_name, dataset),  
            kwargs=kwargs  
        )  
        thread.daemon = True  
        thread.start()  

        return job_id

    def get_training_status(self, job_id: str) -> Dict[str, Any]:
        """Retorna status do treinamento."""
        return self.training_jobs.get(job_id, {"error": "Job não encontrado"})

    def list_trained_models(self) -> List[Dict[str, str]]:
        """Lista modelos LoRA treinados."""
        models = []
        lora_models_dir = Path("./lora_models")

        if lora_models_dir.exists():  
            for model_dir in lora_models_dir.iterdir():  
                if model_dir.is_dir():  
                    config_file = model_dir / "adapter_config.json"  
                    if config_file.exists():  
                        try:  
                            with open(config_file, 'r') as f:  
                                config = json.load(f)  

                            models.append({  
                                "id": model_dir.name,  
                                "path": str(model_dir),  
                                "base_model": config.get("base_model_name", "Unknown"),  
                                "r": config.get("r", "Unknown"),  
                                "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()  
                            })  
                        except Exception as e:  
                            models.append({  
                                "id": model_dir.name,  
                                "path": str(model_dir),  
                                "base_model": "Unknown",  
                                "r": "Unknown",  
                                "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()  
                            })  

        return models

    def create_download_zip(self, model_path: str) -> str:
        """Cria um arquivo ZIP com os arquivos do modelo LoRA para download."""
        zip_path = f"{model_path}.zip"

        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:  
            model_dir = Path(model_path)  
            for file_path in model_dir.rglob('*'):  
                if file_path.is_file():  
                    arcname = file_path.relative_to(model_dir)  
                    zipf.write(file_path, arcname)  

        return zip_path


# Instância global do trainer
trainer = LoRAImageTrainer()

def create_gradio_interface():
    """Cria interface Gradio para a ferramenta LoRA de geração de imagens."""

    # CSS personalizado para responsividade móvel  
    custom_css = """  
    /* Mobile-first responsive design */  
    @media (max-width: 768px) {  
        .gradio-container {  
            padding: 8px !important;  
            margin: 0 !important;  
        }  
          
        .tab-nav {  
            flex-wrap: wrap !important;  
            gap: 4px !important;  
        }  
          
        .tab-nav button {  
            font-size: 14px !important;  
            padding: 8px 12px !important;  
            min-width: auto !important;  
            flex: 1 1 auto !important;  
        }  
          
        .form-container {  
            padding: 12px !important;  
        }  
          
        .btn {  
            width: 100% !important;  
            padding: 12px !important;  
            font-size: 16px !important;  
            margin-bottom: 8px !important;  
            min-height: 44px !important;  
        }  
          
        .textbox textarea {  
            font-size: 16px !important;  
            min-height: 120px !important;  
        }  
          
        .dropdown select {  
            font-size: 16px !important;  
            padding: 12px !important;  
        }  
          
        .output-text {  
            font-size: 14px !important;  
            line-height: 1.5 !important;  
        }  
          
        .column {  
            margin-bottom: 16px !important;  
        }  
          
        .file-upload {  
            min-height: 100px !important;  
        }  
    }  
      
    /* Enhanced visual styles */  
    .lora-header {  
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);  
        color: white;  
        padding: 20px;  
        border-radius: 12px;  
        margin-bottom: 20px;  
        text-align: center;  
        box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);  
    }  
      
    .status-indicator {  
        display: inline-block;  
        padding: 4px 8px;  
        border-radius: 6px;  
        font-size: 12px;  
        font-weight: 600;  
        text-transform: uppercase;  
        letter-spacing: 0.5px;  
        margin-right: 8px;  
    }  
      
    .status-queued { background-color: #fbbf24; color: #92400e; }  
    .status-loading_model { background-color: #60a5fa; color: #1e40af; }  
    .status-preparing_lora { background-color: #8b5cf6; color: #5b21b6; }  
    .status-preparing_data { background-color: #06b6d4; color: #0e7490; }  
    .status-training { background-color: #a78bfa; color: #5b21b6; }  
    .status-saving { background-color: #f59e0b; color: #92400e; }  
    .status-completed { background-color: #34d399; color: #065f46; }  
    .status-error { background-color: #f87171; color: #991b1b; }  
      
    /* Touch device optimizations */  
    @media (hover: none) and (pointer: coarse) {  
        .btn {  
            min-height: 44px !important;  
            min-width: 44px !important;  
        }  
          
        .tab-nav button {  
            min-height: 44px !important;  
            min-width: 44px !important;  
        }  
    }  
    """  
    
    def process_images_and_captions(files, captions_text):  
        """Processa imagens e legendas enviadas pelo usuário."""  
        if not files:  
            return "❌ Erro: Nenhuma imagem foi enviada!"  
        
        # Processar legendas  
        captions = []  
        if captions_text.strip():  
            captions = [line.strip() for line in captions_text.split('\n') if line.strip()]  
        
        # Se não há legendas suficientes, usar legendas padrão  
        while len(captions) < len(files):  
            captions.append(f"training image {len(captions) + 1}")  
        
        # Truncar legendas se houver mais que imagens  
        captions = captions[:len(files)]  
        
        return files, captions
    
    def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, lora_dropout,   
                             num_epochs, learning_rate, batch_size, resolution):  
        """Wrapper para iniciar treinamento via Gradio."""  
            
        if not files:  
            return "❌ Erro: Nenhuma imagem foi enviada para treinamento!"  
        
        if len(files) < 3:  
            return "❌ Erro: Forneça pelo menos 3 imagens para treinamento!"  
        
        try:  
            # Processar imagens e legendas  
            image_files = [f.name for f in files]  
              
            # Processar legendas  
            captions = []  
            if captions_text.strip():  
                captions = [line.strip() for line in captions_text.split('\n') if line.strip()]  
            
            # Se não há legendas suficientes, usar trigger word + descrição padrão  
            while len(captions) < len(files):  
                if trigger_word.strip():  
                    captions.append(f"{trigger_word.strip()}, high quality photo")  
                else:  
                    captions.append(f"training image {len(captions) + 1}, high quality photo")  
            
            # Truncar legendas se houver mais que imagens  
            captions = captions[:len(files)]  
            
            job_id = trainer.start_training(  
                model_name=model_name,  
                image_files=image_files,  
                captions=captions,  
                r=int(r),  
                lora_alpha=int(lora_alpha),  
                lora_dropout=float(lora_dropout),  
                num_epochs=int(num_epochs),  
                learning_rate=float(learning_rate),  
                batch_size=int(batch_size),  
                resolution=int(resolution)  
            )  
            
            return f"✅ Treinamento REAL iniciado! ID do Job: {job_id}\n\n📊 Imagens: {len(files)}\n🏷️ Trigger Word: {trigger_word or 'Nenhuma'}\n\nUse o ID acima para verificar o progresso na aba 'Status do Treinamento'."  
            
        except Exception as e:  
            return f"❌ Erro ao iniciar treinamento: {str(e)}"  
    
    def check_status_wrapper(job_id):  
        """Wrapper para verificar status via Gradio."""  
        if not job_id.strip():  
            return "❌ Erro: Forneça um ID de job válido!"  
        
        status = trainer.get_training_status(job_id.strip())  
        
        if "error" in status and status["error"] == "Job não encontrado":  
            return "❌ Job não encontrado! Verifique o ID."  
        
        # Criar indicador visual de status  
        status_class = f"status-{status['status']}"  
        status_emoji = {  
            'queued': '⏳',  
            'loading_model': '📥',  
            'preparing_lora': '⚙️',  
            'preparing_data': '📊',  
            'training': '🏋️',  
            'saving': '💾',  
            'completed': '✅',  
            'error': '❌'  
        }.get(status['status'], '📊')  
        
        # Barra de progresso visual  
        progress = status['progress']  
        progress_bar = f"""  
        <div style="width: 100%; background-color: #e5e7eb; border-radius: 4px; overflow: hidden; margin: 8px 0;">  
            <div style="width: {progress}%; height: 8px; background: linear-gradient(90deg, #3b82f6, #8b5cf6); transition: width 0.3s ease; border-radius: 4px;"></div>  
        </div>  
        """  
        
        status_text = f"""

📊 Status do Treinamento LoRA

🆔 Job ID: {status['id']}
{status_emoji} Status: <span class="{status_class}">{status['status'].upper().replace('_', ' ')}</span>
⏳ Progresso: {status['progress']}%

{progress_bar}

🤖 Modelo Base: {status['model_name']}
🖼️ Imagens: {status.get('num_images', 'N/A')}
📅 Criado em: {status['created_at']}

"""

        if status['logs']:  
            status_text += "📝 **Logs Recentes:**\n"  
            for log in status['logs'][-5:]:  # Últimos 5 logs  
                status_text += f"• {log}\n"  
        
        if status['status'] == 'completed':  
            status_text += f"\n✅ **Treinamento Concluído!**\n📁 **Modelo salvo em:** {status['model_path']}"  
            status_text += f"\n⏰ **Concluído em:** {status['completed_at']}"  
            status_text += f"\n\n💡 **Próximos passos:** Vá para a aba 'Modelos Treinados' para baixar seu LoRA!"  
        elif status['status'] == 'error':  
            status_text += f"\n❌ **Erro:** {status['error']}"  
        
        return status_text
    
    def list_models_wrapper():  
        """Wrapper para listar modelos via Gradio."""  
        models = trainer.list_trained_models()  
        
        if not models:  
            return "📭 Nenhum modelo LoRA treinado encontrado."  
        
        models_text = "📚 **Modelos LoRA Treinados:**\n\n"  
        for model in models:  
            models_text += f"🆔 **ID:** {model['id']}\n"  
            models_text += f"🤖 **Modelo Base:** {model['base_model']}\n"  
            models_text += f"📊 **Rank (r):** {model['r']}\n"  
            models_text += f"📁 **Caminho:** {model['path']}\n"  
            models_text += f"📅 **Criado:** {model['created']}\n\n"  
            models_text += "---\n\n"  
        
        return models_text
    
    def download_model_wrapper(job_id):  
        """Wrapper para preparar download do modelo."""  
        if not job_id.strip():  
            return None, "❌ Erro: Forneça um ID de job válido!"  
        
        status = trainer.get_training_status(job_id.strip())  
        
        if "error" in status and status["error"] == "Job não encontrado":  
            return None, "❌ Job não encontrado! Verifique o ID."  
        
        if status['status'] != 'completed':  
            return None, f"❌ Treinamento ainda não foi concluído. Status atual: {status['status']}"  
        
        try:  
            model_path = status['model_path']  
            zip_path = trainer.create_download_zip(model_path)  
            
            return zip_path, f"✅ Arquivo ZIP criado com sucesso! Clique no link acima para baixar."  
            
        except Exception as e:  
            return None, f"❌ Erro ao criar arquivo de download: {str(e)}"  
    
    # Interface Gradio  
    with gr.Blocks(  
        title="🎨 LoRA Image Trainer - Criador e Treinador de LoRA para Imagens",   
        theme=gr.themes.Soft(),  
        css=custom_css  
    ) as interface:  
        
        gr.HTML("""  
        <div class="lora-header">  
            <h1>🎨 LoRA Image Trainer</h1>  
            <p>Criador e Treinador de LoRA para Geração de Imagens</p>  
            <p style="font-size: 0.9em; opacity: 0.9; margin-top: 8px;">  
                Ferramenta otimizada para baixo uso de GPU, compatível com dispositivos móveis  
            </p>  
        </div>  
        """)  
        
        with gr.Tabs():  
            
            # Aba de Treinamento  
            with gr.TabItem("🎯 Treinar LoRA"):  
                gr.Markdown("### Configurar e Iniciar Treinamento LoRA para Imagens")  
                
                with gr.Row():  
                    with gr.Column(scale=2):  
                        model_dropdown = gr.Dropdown(  
                            choices=trainer.get_available_models(),  
                            value="runwayml/stable-diffusion-v1-5",  
                            label="🤖 Modelo Base",  
                        )  
                        
                        image_files = gr.File(  
                            file_count="multiple",  
                            file_types=["image"],  
                            label="🖼️ Imagens de Treinamento",  
                        )  
                        
                        trigger_word = gr.Textbox(  
                            label="🏷️ Trigger Word (Opcional)",  
                            placeholder="ex: meuEstilo, minhaPersonagem, etc.",  
                        )  
                        
                        captions_text = gr.Textbox(  
                            lines=8,  
                            placeholder="Digite uma legenda por linha (opcional)...\n\nExemplo:\nmeuEstilo, retrato de uma mulher\nmeuEstilo, homem sorrindo\nmeuEstilo, paisagem urbana\n\nSe deixar vazio, usará a trigger word + 'high quality photo'",  
                            label="📝 Legendas das Imagens (Opcional)",  
                        )  
                    
                    with gr.Column(scale=1):  
                        gr.Markdown("### ⚙️ Parâmetros LoRA")  
                        
                        r = gr.Slider(  
                            minimum=4, maximum=64, value=8, step=4,  # reduzido max para 64
                            label="r (Rank)",  
                        )  
                        
                        lora_alpha = gr.Slider(  
                            minimum=1, maximum=64, value=16, step=1,  # reduzido max para 64
                            label="LoRA Alpha",  
                        )  
                        
                        lora_dropout = gr.Slider(  
                            minimum=0.0, maximum=0.5, value=0.0, step=0.05,  # dropout 0 para mais estabilidade
                            label="LoRA Dropout",  
                        )  
                        
                        gr.Markdown("### 🏋️ Parâmetros de Treinamento")  
                        
                        num_epochs = gr.Slider(  
                            minimum=5, maximum=20, value=10, step=5,  # reduzido max para 20
                            label="Épocas",  
                        )  
                        
                        learning_rate = gr.Slider(  
                            minimum=1e-5, maximum=5e-4, value=1e-4, step=1e-5,  # reduzido max
                            label="Taxa de Aprendizado",  
                        )  
                        
                        batch_size = gr.Slider(  
                            minimum=1, maximum=1, value=1, step=1,  # fixado em 1 para Spaces
                            label="Batch Size",  
                        )  
                        
                        resolution = gr.Dropdown(  
                            choices=[512],  # fixado em 512 para garantir funcionamento em GPU limitada
                            value=512,  
                            label="Resolução",  
                        )  
            
                train_button = gr.Button("🚀 Iniciar Treinamento LoRA", variant="primary", size="lg")  
                train_output = gr.Textbox(label="📊 Resultado", lines=5)  
                
                train_button.click(  
                    start_training_wrapper,  
                    inputs=[model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, lora_dropout,   
                           num_epochs, learning_rate, batch_size, resolution],  
                    outputs=train_output  
                )  
            
            # Aba de Status  
            with gr.TabItem("📊 Status do Treinamento"):  
                gr.Markdown("### Verificar Progresso do Treinamento")  
                
                job_id_input = gr.Textbox(  
                    label="🆔 ID do Job",  
                    placeholder="Cole aqui o ID do job de treinamento...",  
                )  
                
                status_button = gr.Button("🔍 Verificar Status", variant="secondary")  
                status_output = gr.Textbox(label="📈 Status", lines=12)  
                
                status_button.click(  
                    check_status_wrapper,  
                    inputs=job_id_input,  
                    outputs=status_output  
                )  
                
                gr.Markdown("💡 **Dica:** Atualize o status regularmente para acompanhar o progresso do treinamento.")  
            
            # Aba de Modelos e Download  
            with gr.TabItem("📚 Modelos e Download"):  
                gr.Markdown("### Visualizar e Baixar Modelos LoRA Treinados")  
                
                with gr.Row():  
                    with gr.Column(scale=1):  
                        list_button = gr.Button("📋 Listar Modelos", variant="secondary")  
                        models_output = gr.Textbox(label="📚 Modelos Disponíveis", lines=10)  
                        
                        list_button.click(  
                            list_models_wrapper,  
                            outputs=models_output  
                        )  
                    
                    with gr.Column(scale=1):  
                        gr.Markdown("#### 💾 Download de Modelo")  
                        
                        download_job_id = gr.Textbox(  
                            label="🆔 ID do Job para Download",  
                            placeholder="Cole o ID do job concluído...",  
                        )  
                        
                        download_button = gr.Button("📦 Preparar Download", variant="primary")  
                        download_file = gr.File(label="📁 Arquivo para Download")  
                        download_status = gr.Textbox(label="📊 Status do Download", lines=3)  
                        
                        download_button.click(  
                            download_model_wrapper,  
                            inputs=download_job_id,  
                            outputs=[download_file, download_status]  
                        )  
            
            # Aba de Informações  
            with gr.TabItem("ℹ️ Sobre"):  
                gr.Markdown("""  
                ### 🎯 Sobre o LoRA Image Trainer  
                  
                Esta ferramenta foi desenvolvida para democratizar o acesso ao treinamento de modelos LoRA para geração de imagens,   
                permitindo que qualquer pessoa possa criar adaptações personalizadas de modelos de difusão (como Stable Diffusion)   
                sem a necessidade de hardware especializado.  
                  
                #### ✨ Características Principais:  
                  
                - **🔋 Otimizado para Baixa GPU**: Utiliza técnicas como mixed precision, gradient checkpointing e configurações otimizadas  
                - **📱 Compatível com Móveis**: Interface responsiva que funciona em smartphones e tablets  
                - **⚡ Rápido e Eficiente**: Treinamento otimizado com bibliotecas Diffusers e PEFT do Hugging Face  
                - **🎛️ Configurável**: Controle total sobre parâmetros LoRA e de treinamento  
                - **☁️ Pronto para Deploy**: Facilmente implantável no Hugging Face Spaces  
                - **🎨 Focado em Imagens**: Especificamente projetado para modelos de difusão e geração de imagens  
                  
                #### 🛠️ Tecnologias Utilizadas:  
                  
                - **Hugging Face Diffusers**: Para modelos de difusão e pipeline de treinamento  
                - **PEFT (Parameter-Efficient Fine-Tuning)**: Para treinamento eficiente de LoRA  
                - **PyTorch**: Framework de deep learning  
                - **Gradio**: Interface web interativa e responsiva  
                - **LoRA (Low-Rank Adaptation)**: Técnica de fine-tuning eficiente para modelos de difusão  
                  
                #### 📖 Como Usar:  
                  
                1. **Prepare suas imagens**: Colete 3-50 imagens de alta qualidade do estilo/conceito que deseja treinar  
                2. **Escolha um modelo base** na aba "Treinar LoRA" (recomendado: Stable Diffusion 1.5)  
                3. **Faça upload das imagens** e defina uma trigger word (palavra-chave)  
                4. **Configure os parâmetros** conforme necessário (valores padrão funcionam bem)  
                5. **Inicie o treinamento** e anote o ID do job  
                6. **Acompanhe o progresso** na aba "Status do Treinamento"  
                7. **Baixe seu LoRA** na aba "Modelos e Download" quando concluído  
                8. **Use em suas ferramentas favoritas** (ComfyUI, Automatic1111, etc.)  
                  
                #### 💡 Dicas para Melhores Resultados:  
                  
                - **Qualidade > Quantidade**: 10-20 imagens de alta qualidade são melhores que 50 imagens ruins  
                - **Consistência**: Use imagens com estilo/conceito consistente  
                - **Resolução**: Para GPUs com pouca VRAM, use resolução 512x512  
                - **Trigger Word**: Escolha uma palavra única e fácil de lembrar  
                - **Legendas**: Descreva o que há nas imagens para melhor controle  
                - **Parâmetros**: Para iniciantes, use os valores padrão  
                  
                #### 🎮 Compatibilidade:  
                  
                Os LoRAs gerados são compatíveis com:  
                - **ComfyUI**: Carregue os arquivos .safetensors  
                - **Automatic1111**: Coloque na pasta models/Lora  
                - **SeaArt**: Faça upload do modelo  
                - **Outras ferramentas**: Qualquer ferramenta que suporte LoRA para Stable Diffusion  
                  
                ---  
                  
                **Desenvolvido com ❤️ para a comunidade de IA e arte digital**  
                """)  
        
        # Footer  
        gr.Markdown("""  
        ---  
        <div style="text-align: center; color: #666; font-size: 0.9em;">  
            🎨 LoRA Image Trainer v1.0 | Otimizado para Baixa GPU | Compatível com Dispositivos Móveis  
        </div>  
        """)  
    
    return interface

# Criar e configurar interface
if __name__ == "__main__":
    # Criar diretórios necessários
    os.makedirs("./lora_models", exist_ok=True)

    # Configurar interface  
    interface = create_gradio_interface()  
  
    # Lançar aplicação  
    interface.launch(  
        server_name="0.0.0.0",  
        server_port=7860,  
        share=False,  
        show_error=True,  
        quiet=False  
    )