File size: 9,034 Bytes
0fad32a
 
bd4975f
e984a2c
 
b8975f7
b9a20a2
b8975f7
0fad32a
 
 
 
e984a2c
0fad32a
 
 
 
 
 
f3f26af
0fad32a
f3f26af
 
 
f5e7ff1
b8975f7
f5e7ff1
0fad32a
b8975f7
0fad32a
 
 
 
bd4975f
0fad32a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5e7ff1
0fad32a
f5e7ff1
 
0fad32a
b8975f7
 
 
1280c43
f3f26af
3cb91bd
 
0fad32a
f3f26af
b8975f7
 
 
0fad32a
f3f26af
b6a9837
0fad32a
 
 
 
 
 
 
 
 
 
 
 
 
1280c43
 
 
 
 
 
 
0fad32a
b8975f7
0fad32a
 
1280c43
 
b8975f7
3ce81f3
0fad32a
 
 
 
f5e7ff1
1280c43
 
 
 
 
 
 
 
 
0fad32a
 
b8975f7
0fad32a
 
 
 
 
 
 
 
 
 
 
b8975f7
0fad32a
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# app.py - Carrega base + UNet de repo privado separado
# Data e hora atuais para referência: Sunday, May 4, 2025 at 8:23:22 PM -03
import os, random, uuid, json
import gradio as gr
import numpy as np
from PIL import Image
import spaces
import torch
# Importar UNet também
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerAncestralDiscreteScheduler
import time
from huggingface_hub import HfApi

# --- Configurações ---
base_model_id = "sd-community/sdxl-flash" # Ou o base que o Space usava
# ID do Repositório PRIVADO que contém APENAS o UNet treinado
tuned_unet_repo_id = "borsojj/unet" # <<< Repo com seu UNet

DESCRIPTION = f"Interface usando base `{base_model_id}` com UNet de `{tuned_unet_repo_id}`."
if not torch.cuda.is_available():
    DESCRIPTION += "\n**Atenção:** Rodando em CPU 🥶 - A geração pode ser muito lenta ou falhar."

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print(f"Carregando pipeline base de: {base_model_id}...")
hf_token = os.getenv("HF_TOKEN") # Pega token dos segredos do Space
if hf_token:
    print("Segredo HF_TOKEN encontrado.")
else:
    # AVISO IMPORTANTE se o repo UNET for privado!
    print("AVISO: Segredo HF_TOKEN NÃO encontrado. O carregamento do UNet treinado falhará se o repositório for privado.")

start_time = time.time()
pipe = None
loading_error_message = ""
try:
    # 1. Carrega o pipeline base completo
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base_model_id,
        torch_dtype=torch_dtype,
        use_safetensors=True,
        add_watermarker=False,
        token=hf_token # Passa token caso o base precise também
    )
    print(f"Pipeline base '{base_model_id}' carregado.")

    # 2. Tenta carregar e substituir o UNet do repo separado e PRIVADO
    if hf_token: # Só tenta carregar se houver token
        print(f"Tentando carregar UNet treinado do repo privado: {tuned_unet_repo_id}")
        try:
            # Carrega o UNet do repo ID, usando o token
            tuned_unet = UNet2DConditionModel.from_pretrained(
                tuned_unet_repo_id,
                torch_dtype=torch_dtype, # Carrega com mesmo dtype
                token=hf_token # Usa o token para acessar repo privado
            )
            print("UNet treinado carregado. Substituindo UNet no pipeline...")
            pipe.unet = tuned_unet # A SUBSTITUIÇÃO
            print("UNet substituído com sucesso.")
        except Exception as unet_load_e:
            loading_error_message = f"**<font color='red'>ERRO:</font>** Falha ao carregar UNet de `{tuned_unet_repo_id}` (Verifique token e repo). Usando UNet base. Erro: `{unet_load_e}`"
            print(loading_error_message)
            DESCRIPTION += "\n" + loading_error_message
            # Continua com o UNet base se falhar
    else:
         # Se não há token, não pode carregar UNet privado
         loading_error_message = f"**<font color='orange'>AVISO:</font>** HF_TOKEN não encontrado. Não é possível carregar UNet do repositório privado `{tuned_unet_repo_id}`. Usando UNet base."
         print(loading_error_message)
         DESCRIPTION += "\n" + loading_error_message

    # Configura o scheduler (como antes)
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    print(f"Scheduler configurado para: {pipe.scheduler.__class__.__name__}")

    print(f"Movendo pipeline final para o device: {device}")
    pipe.to(device) # Move para o device APÓS substituir o UNet
    print("Pipeline pronto no device.")

except Exception as e:
    # Erro ao carregar o pipeline BASE
    print(f"Erro CRÍTICO ao carregar o pipeline base de '{base_model_id}': {e}")
    loading_error_message = f"**<font color='red'>ERRO CRÍTICO:</font>** Não foi possível carregar o pipeline base `{base_model_id}`. Erro: `{e}`."
    DESCRIPTION += "\n" + loading_error_message
    pipe = None

# Função generate (sem alterações significativas)
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed: seed = random.randint(0, MAX_SEED)
    return seed

@spaces.GPU(duration=90)
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 7.0,
    num_inference_steps: int = 25,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    if pipe is None: raise gr.Error(f"Pipeline não carregado. {loading_error_message}")
    pipe.to(device)
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)
    if not use_negative_prompt: negative_prompt = None
    options = {"prompt":prompt, "negative_prompt":negative_prompt, "width":width, "height":height, "guidance_scale":guidance_scale, "num_inference_steps":num_inference_steps, "generator":generator, "output_type":"pil"}
    print(f"Gerando imagem com seed: {seed}, Steps: {num_inference_steps}, Guidance: {guidance_scale}")
    start_gen_time = time.time()
    try:
        images = pipe(**options).images
        print(f"Gerado {len(images)} imagem(s) em {time.time() - start_gen_time:.2f} segundos.")
        return images, seed
    except Exception as e:
        print(f"Erro durante a geração: {e}")
        raise gr.Error(f"Erro durante a geração: {e}")

# Interface Gradio
examples = [
    "photo of a futuristic city skyline at sunset, high detail",
    "an oil painting of a cute cat wearing a wizard hat",
    "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
    "An alien grasping a sign board contain word 'Flash', futuristic, neonpunk, detailed"
]
css = ".gradio-container { max-width: 800px !important; margin: 0 auto !important; } h1{ text-align:center }"
with gr.Blocks(css=css) as demo:
    gr.Markdown(f"""# SDXL Base com UNet Fine-tuned (`{tuned_unet_repo_id}`)
        Base: `{base_model_id}`
        {DESCRIPTION}
        **Aviso:** O filtro de conteúdo explícito foi desativado. Use prompts com cuidado.""")
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(label="Prompt", show_label=False, max_lines=3, placeholder="Descreva a imagem...", container=False)
            run_button = gr.Button("Gerar Imagem", variant="primary", scale=0)
        result = gr.Gallery(label="Resultado", show_label=False, elem_id="gallery", columns=1, height=768)
    with gr.Accordion("Opções Avançadas", open=False):
        with gr.Row():
            use_negative_prompt = gr.Checkbox(label="Usar prompt negativo", value=False)
            negative_prompt = gr.Text(
                label="Prompt Negativo",
                max_lines=3,
                lines=2,
                placeholder="O que evitar na imagem (ex.: blurry, deformed)...",
                value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, blurry, amputation",
                visible=False
            )
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
        randomize_seed = gr.Checkbox(label="Seed Aleatória", value=True)
        with gr.Row():
            width = gr.Slider(label="Largura", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
            height = gr.Slider(label="Altura", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
        with gr.Row():
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.5, value=7.0)
            num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=25)
    gr.Examples(examples=examples, inputs=prompt, outputs=[result, seed], fn=generate, cache_examples=CACHE_EXAMPLES)
    use_negative_prompt.change(fn=lambda x: gr.update(visible=x), inputs=use_negative_prompt, outputs=negative_prompt, api_name=False)
    generate_inputs = [prompt, negative_prompt, use_negative_prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed]
    gr.on(triggers=[prompt.submit, run_button.click], fn=generate, inputs=generate_inputs, outputs=[result, seed], api_name="generate_image")

# Lança a interface no ambiente Space
if __name__ == "__main__":
    if pipe is not None:
         print("Lançando interface Gradio no Space...")
         demo.queue().launch() # Importante para Spaces
    else:
         print("ERRO CRÍTICO: Pipeline não carregado ou UNet não substituído corretamente. Lançando UI de erro.")
         with gr.Blocks() as error_demo:
              gr.Markdown(f"# Erro ao Carregar Modelo\n{DESCRIPTION}")
         error_demo.queue().launch()