Spaces:
Runtime error
Runtime error
| # app.py - Carrega base + UNet de repo privado separado | |
| # Data e hora atuais para referência: Sunday, May 4, 2025 at 8:23:22 PM -03 | |
| import os, random, uuid, json | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import spaces | |
| import torch | |
| # Importar UNet também | |
| from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerAncestralDiscreteScheduler | |
| import time | |
| from huggingface_hub import HfApi | |
| # --- Configurações --- | |
| base_model_id = "sd-community/sdxl-flash" # Ou o base que o Space usava | |
| # ID do Repositório PRIVADO que contém APENAS o UNet treinado | |
| tuned_unet_repo_id = "borsojj/unet" # <<< Repo com seu UNet | |
| DESCRIPTION = f"Interface usando base `{base_model_id}` com UNet de `{tuned_unet_repo_id}`." | |
| if not torch.cuda.is_available(): | |
| DESCRIPTION += "\n**Atenção:** Rodando em CPU 🥶 - A geração pode ser muito lenta ou falhar." | |
| MAX_SEED = np.iinfo(np.int32).max | |
| CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1" | |
| MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096")) | |
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
| torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 | |
| print(f"Carregando pipeline base de: {base_model_id}...") | |
| hf_token = os.getenv("HF_TOKEN") # Pega token dos segredos do Space | |
| if hf_token: | |
| print("Segredo HF_TOKEN encontrado.") | |
| else: | |
| # AVISO IMPORTANTE se o repo UNET for privado! | |
| print("AVISO: Segredo HF_TOKEN NÃO encontrado. O carregamento do UNet treinado falhará se o repositório for privado.") | |
| start_time = time.time() | |
| pipe = None | |
| loading_error_message = "" | |
| try: | |
| # 1. Carrega o pipeline base completo | |
| pipe = StableDiffusionXLPipeline.from_pretrained( | |
| base_model_id, | |
| torch_dtype=torch_dtype, | |
| use_safetensors=True, | |
| add_watermarker=False, | |
| token=hf_token # Passa token caso o base precise também | |
| ) | |
| print(f"Pipeline base '{base_model_id}' carregado.") | |
| # 2. Tenta carregar e substituir o UNet do repo separado e PRIVADO | |
| if hf_token: # Só tenta carregar se houver token | |
| print(f"Tentando carregar UNet treinado do repo privado: {tuned_unet_repo_id}") | |
| try: | |
| # Carrega o UNet do repo ID, usando o token | |
| tuned_unet = UNet2DConditionModel.from_pretrained( | |
| tuned_unet_repo_id, | |
| torch_dtype=torch_dtype, # Carrega com mesmo dtype | |
| token=hf_token # Usa o token para acessar repo privado | |
| ) | |
| print("UNet treinado carregado. Substituindo UNet no pipeline...") | |
| pipe.unet = tuned_unet # A SUBSTITUIÇÃO | |
| print("UNet substituído com sucesso.") | |
| except Exception as unet_load_e: | |
| loading_error_message = f"**<font color='red'>ERRO:</font>** Falha ao carregar UNet de `{tuned_unet_repo_id}` (Verifique token e repo). Usando UNet base. Erro: `{unet_load_e}`" | |
| print(loading_error_message) | |
| DESCRIPTION += "\n" + loading_error_message | |
| # Continua com o UNet base se falhar | |
| else: | |
| # Se não há token, não pode carregar UNet privado | |
| loading_error_message = f"**<font color='orange'>AVISO:</font>** HF_TOKEN não encontrado. Não é possível carregar UNet do repositório privado `{tuned_unet_repo_id}`. Usando UNet base." | |
| print(loading_error_message) | |
| DESCRIPTION += "\n" + loading_error_message | |
| # Configura o scheduler (como antes) | |
| pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) | |
| print(f"Scheduler configurado para: {pipe.scheduler.__class__.__name__}") | |
| print(f"Movendo pipeline final para o device: {device}") | |
| pipe.to(device) # Move para o device APÓS substituir o UNet | |
| print("Pipeline pronto no device.") | |
| except Exception as e: | |
| # Erro ao carregar o pipeline BASE | |
| print(f"Erro CRÍTICO ao carregar o pipeline base de '{base_model_id}': {e}") | |
| loading_error_message = f"**<font color='red'>ERRO CRÍTICO:</font>** Não foi possível carregar o pipeline base `{base_model_id}`. Erro: `{e}`." | |
| DESCRIPTION += "\n" + loading_error_message | |
| pipe = None | |
| # Função generate (sem alterações significativas) | |
| def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: | |
| if randomize_seed: seed = random.randint(0, MAX_SEED) | |
| return seed | |
| def generate( | |
| prompt: str, | |
| negative_prompt: str = "", | |
| use_negative_prompt: bool = False, | |
| seed: int = 0, | |
| width: int = 1024, | |
| height: int = 1024, | |
| guidance_scale: float = 7.0, | |
| num_inference_steps: int = 25, | |
| randomize_seed: bool = False, | |
| progress=gr.Progress(track_tqdm=True), | |
| ): | |
| if pipe is None: raise gr.Error(f"Pipeline não carregado. {loading_error_message}") | |
| pipe.to(device) | |
| seed = int(randomize_seed_fn(seed, randomize_seed)) | |
| generator = torch.Generator(device=device).manual_seed(seed) | |
| if not use_negative_prompt: negative_prompt = None | |
| options = {"prompt":prompt, "negative_prompt":negative_prompt, "width":width, "height":height, "guidance_scale":guidance_scale, "num_inference_steps":num_inference_steps, "generator":generator, "output_type":"pil"} | |
| print(f"Gerando imagem com seed: {seed}, Steps: {num_inference_steps}, Guidance: {guidance_scale}") | |
| start_gen_time = time.time() | |
| try: | |
| images = pipe(**options).images | |
| print(f"Gerado {len(images)} imagem(s) em {time.time() - start_gen_time:.2f} segundos.") | |
| return images, seed | |
| except Exception as e: | |
| print(f"Erro durante a geração: {e}") | |
| raise gr.Error(f"Erro durante a geração: {e}") | |
| # Interface Gradio | |
| examples = [ | |
| "photo of a futuristic city skyline at sunset, high detail", | |
| "an oil painting of a cute cat wearing a wizard hat", | |
| "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k", | |
| "An alien grasping a sign board contain word 'Flash', futuristic, neonpunk, detailed" | |
| ] | |
| css = ".gradio-container { max-width: 800px !important; margin: 0 auto !important; } h1{ text-align:center }" | |
| with gr.Blocks(css=css) as demo: | |
| gr.Markdown(f"""# SDXL Base com UNet Fine-tuned (`{tuned_unet_repo_id}`) | |
| Base: `{base_model_id}` | |
| {DESCRIPTION} | |
| **Aviso:** O filtro de conteúdo explícito foi desativado. Use prompts com cuidado.""") | |
| with gr.Group(): | |
| with gr.Row(): | |
| prompt = gr.Text(label="Prompt", show_label=False, max_lines=3, placeholder="Descreva a imagem...", container=False) | |
| run_button = gr.Button("Gerar Imagem", variant="primary", scale=0) | |
| result = gr.Gallery(label="Resultado", show_label=False, elem_id="gallery", columns=1, height=768) | |
| with gr.Accordion("Opções Avançadas", open=False): | |
| with gr.Row(): | |
| use_negative_prompt = gr.Checkbox(label="Usar prompt negativo", value=False) | |
| negative_prompt = gr.Text( | |
| label="Prompt Negativo", | |
| max_lines=3, | |
| lines=2, | |
| placeholder="O que evitar na imagem (ex.: blurry, deformed)...", | |
| value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, blurry, amputation", | |
| visible=False | |
| ) | |
| seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0) | |
| randomize_seed = gr.Checkbox(label="Seed Aleatória", value=True) | |
| with gr.Row(): | |
| width = gr.Slider(label="Largura", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024) | |
| height = gr.Slider(label="Altura", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024) | |
| with gr.Row(): | |
| guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.5, value=7.0) | |
| num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=25) | |
| gr.Examples(examples=examples, inputs=prompt, outputs=[result, seed], fn=generate, cache_examples=CACHE_EXAMPLES) | |
| use_negative_prompt.change(fn=lambda x: gr.update(visible=x), inputs=use_negative_prompt, outputs=negative_prompt, api_name=False) | |
| generate_inputs = [prompt, negative_prompt, use_negative_prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed] | |
| gr.on(triggers=[prompt.submit, run_button.click], fn=generate, inputs=generate_inputs, outputs=[result, seed], api_name="generate_image") | |
| # Lança a interface no ambiente Space | |
| if __name__ == "__main__": | |
| if pipe is not None: | |
| print("Lançando interface Gradio no Space...") | |
| demo.queue().launch() # Importante para Spaces | |
| else: | |
| print("ERRO CRÍTICO: Pipeline não carregado ou UNet não substituído corretamente. Lançando UI de erro.") | |
| with gr.Blocks() as error_demo: | |
| gr.Markdown(f"# Erro ao Carregar Modelo\n{DESCRIPTION}") | |
| error_demo.queue().launch() |