Spaces:
Runtime error
Runtime error
File size: 9,034 Bytes
0fad32a bd4975f e984a2c b8975f7 b9a20a2 b8975f7 0fad32a e984a2c 0fad32a f3f26af 0fad32a f3f26af f5e7ff1 b8975f7 f5e7ff1 0fad32a b8975f7 0fad32a bd4975f 0fad32a f5e7ff1 0fad32a f5e7ff1 0fad32a b8975f7 1280c43 f3f26af 3cb91bd 0fad32a f3f26af b8975f7 0fad32a f3f26af b6a9837 0fad32a 1280c43 0fad32a b8975f7 0fad32a 1280c43 b8975f7 3ce81f3 0fad32a f5e7ff1 1280c43 0fad32a b8975f7 0fad32a b8975f7 0fad32a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | # app.py - Carrega base + UNet de repo privado separado
# Data e hora atuais para referência: Sunday, May 4, 2025 at 8:23:22 PM -03
import os, random, uuid, json
import gradio as gr
import numpy as np
from PIL import Image
import spaces
import torch
# Importar UNet também
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerAncestralDiscreteScheduler
import time
from huggingface_hub import HfApi
# --- Configurações ---
base_model_id = "sd-community/sdxl-flash" # Ou o base que o Space usava
# ID do Repositório PRIVADO que contém APENAS o UNet treinado
tuned_unet_repo_id = "borsojj/unet" # <<< Repo com seu UNet
DESCRIPTION = f"Interface usando base `{base_model_id}` com UNet de `{tuned_unet_repo_id}`."
if not torch.cuda.is_available():
DESCRIPTION += "\n**Atenção:** Rodando em CPU 🥶 - A geração pode ser muito lenta ou falhar."
MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(f"Carregando pipeline base de: {base_model_id}...")
hf_token = os.getenv("HF_TOKEN") # Pega token dos segredos do Space
if hf_token:
print("Segredo HF_TOKEN encontrado.")
else:
# AVISO IMPORTANTE se o repo UNET for privado!
print("AVISO: Segredo HF_TOKEN NÃO encontrado. O carregamento do UNet treinado falhará se o repositório for privado.")
start_time = time.time()
pipe = None
loading_error_message = ""
try:
# 1. Carrega o pipeline base completo
pipe = StableDiffusionXLPipeline.from_pretrained(
base_model_id,
torch_dtype=torch_dtype,
use_safetensors=True,
add_watermarker=False,
token=hf_token # Passa token caso o base precise também
)
print(f"Pipeline base '{base_model_id}' carregado.")
# 2. Tenta carregar e substituir o UNet do repo separado e PRIVADO
if hf_token: # Só tenta carregar se houver token
print(f"Tentando carregar UNet treinado do repo privado: {tuned_unet_repo_id}")
try:
# Carrega o UNet do repo ID, usando o token
tuned_unet = UNet2DConditionModel.from_pretrained(
tuned_unet_repo_id,
torch_dtype=torch_dtype, # Carrega com mesmo dtype
token=hf_token # Usa o token para acessar repo privado
)
print("UNet treinado carregado. Substituindo UNet no pipeline...")
pipe.unet = tuned_unet # A SUBSTITUIÇÃO
print("UNet substituído com sucesso.")
except Exception as unet_load_e:
loading_error_message = f"**<font color='red'>ERRO:</font>** Falha ao carregar UNet de `{tuned_unet_repo_id}` (Verifique token e repo). Usando UNet base. Erro: `{unet_load_e}`"
print(loading_error_message)
DESCRIPTION += "\n" + loading_error_message
# Continua com o UNet base se falhar
else:
# Se não há token, não pode carregar UNet privado
loading_error_message = f"**<font color='orange'>AVISO:</font>** HF_TOKEN não encontrado. Não é possível carregar UNet do repositório privado `{tuned_unet_repo_id}`. Usando UNet base."
print(loading_error_message)
DESCRIPTION += "\n" + loading_error_message
# Configura o scheduler (como antes)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
print(f"Scheduler configurado para: {pipe.scheduler.__class__.__name__}")
print(f"Movendo pipeline final para o device: {device}")
pipe.to(device) # Move para o device APÓS substituir o UNet
print("Pipeline pronto no device.")
except Exception as e:
# Erro ao carregar o pipeline BASE
print(f"Erro CRÍTICO ao carregar o pipeline base de '{base_model_id}': {e}")
loading_error_message = f"**<font color='red'>ERRO CRÍTICO:</font>** Não foi possível carregar o pipeline base `{base_model_id}`. Erro: `{e}`."
DESCRIPTION += "\n" + loading_error_message
pipe = None
# Função generate (sem alterações significativas)
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
if randomize_seed: seed = random.randint(0, MAX_SEED)
return seed
@spaces.GPU(duration=90)
def generate(
prompt: str,
negative_prompt: str = "",
use_negative_prompt: bool = False,
seed: int = 0,
width: int = 1024,
height: int = 1024,
guidance_scale: float = 7.0,
num_inference_steps: int = 25,
randomize_seed: bool = False,
progress=gr.Progress(track_tqdm=True),
):
if pipe is None: raise gr.Error(f"Pipeline não carregado. {loading_error_message}")
pipe.to(device)
seed = int(randomize_seed_fn(seed, randomize_seed))
generator = torch.Generator(device=device).manual_seed(seed)
if not use_negative_prompt: negative_prompt = None
options = {"prompt":prompt, "negative_prompt":negative_prompt, "width":width, "height":height, "guidance_scale":guidance_scale, "num_inference_steps":num_inference_steps, "generator":generator, "output_type":"pil"}
print(f"Gerando imagem com seed: {seed}, Steps: {num_inference_steps}, Guidance: {guidance_scale}")
start_gen_time = time.time()
try:
images = pipe(**options).images
print(f"Gerado {len(images)} imagem(s) em {time.time() - start_gen_time:.2f} segundos.")
return images, seed
except Exception as e:
print(f"Erro durante a geração: {e}")
raise gr.Error(f"Erro durante a geração: {e}")
# Interface Gradio
examples = [
"photo of a futuristic city skyline at sunset, high detail",
"an oil painting of a cute cat wearing a wizard hat",
"Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
"An alien grasping a sign board contain word 'Flash', futuristic, neonpunk, detailed"
]
css = ".gradio-container { max-width: 800px !important; margin: 0 auto !important; } h1{ text-align:center }"
with gr.Blocks(css=css) as demo:
gr.Markdown(f"""# SDXL Base com UNet Fine-tuned (`{tuned_unet_repo_id}`)
Base: `{base_model_id}`
{DESCRIPTION}
**Aviso:** O filtro de conteúdo explícito foi desativado. Use prompts com cuidado.""")
with gr.Group():
with gr.Row():
prompt = gr.Text(label="Prompt", show_label=False, max_lines=3, placeholder="Descreva a imagem...", container=False)
run_button = gr.Button("Gerar Imagem", variant="primary", scale=0)
result = gr.Gallery(label="Resultado", show_label=False, elem_id="gallery", columns=1, height=768)
with gr.Accordion("Opções Avançadas", open=False):
with gr.Row():
use_negative_prompt = gr.Checkbox(label="Usar prompt negativo", value=False)
negative_prompt = gr.Text(
label="Prompt Negativo",
max_lines=3,
lines=2,
placeholder="O que evitar na imagem (ex.: blurry, deformed)...",
value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, blurry, amputation",
visible=False
)
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
randomize_seed = gr.Checkbox(label="Seed Aleatória", value=True)
with gr.Row():
width = gr.Slider(label="Largura", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
height = gr.Slider(label="Altura", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
with gr.Row():
guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.5, value=7.0)
num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=25)
gr.Examples(examples=examples, inputs=prompt, outputs=[result, seed], fn=generate, cache_examples=CACHE_EXAMPLES)
use_negative_prompt.change(fn=lambda x: gr.update(visible=x), inputs=use_negative_prompt, outputs=negative_prompt, api_name=False)
generate_inputs = [prompt, negative_prompt, use_negative_prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed]
gr.on(triggers=[prompt.submit, run_button.click], fn=generate, inputs=generate_inputs, outputs=[result, seed], api_name="generate_image")
# Lança a interface no ambiente Space
if __name__ == "__main__":
if pipe is not None:
print("Lançando interface Gradio no Space...")
demo.queue().launch() # Importante para Spaces
else:
print("ERRO CRÍTICO: Pipeline não carregado ou UNet não substituído corretamente. Lançando UI de erro.")
with gr.Blocks() as error_demo:
gr.Markdown(f"# Erro ao Carregar Modelo\n{DESCRIPTION}")
error_demo.queue().launch() |