Spaces:
Sleeping
Sleeping
| # --- Fix 1: Set Matplotlib backend --- | |
| import matplotlib | |
| matplotlib.use('Agg') # Set backend BEFORE importing pyplot or other conflicting libs | |
| # --- End Fix 1 --- | |
| import gradio as gr | |
| import torch | |
| from diffusers import EulerAncestralDiscreteScheduler | |
| from DoodlePix_pipeline import StableDiffusionInstructPix2PixPipeline | |
| from PIL import Image, ImageOps # Added ImageOps for inversion | |
| import numpy as np | |
| import os | |
| import importlib | |
| import traceback # For detailed error printing | |
# --- FidelityMLP Class ---
class FidelityMLP(torch.nn.Module):
    """MLP that maps a scalar input to an embedding vector.

    The network takes tensors whose last dimension is 1 (presumably the
    fidelity level -- confirm against the pipeline that consumes it), runs
    them through three Linear -> LayerNorm -> activation stages and projects
    the result to ``output_size`` features.

    Args:
        hidden_size: Width of the last hidden stage.
        output_size: Width of the projected output. Falls back to
            ``hidden_size`` when omitted (or falsy).
    """

    def __init__(self, hidden_size, output_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size or hidden_size
        self.net = torch.nn.Sequential(
            torch.nn.Linear(1, 128), torch.nn.LayerNorm(128), torch.nn.SiLU(),
            torch.nn.Linear(128, 256), torch.nn.LayerNorm(256), torch.nn.SiLU(),
            torch.nn.Linear(256, hidden_size), torch.nn.LayerNorm(hidden_size), torch.nn.Tanh(),
        )
        self.output_proj = torch.nn.Linear(hidden_size, self.output_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear layers with N(0, 0.01) weights and zero bias."""
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x, target_dim=None):
        """Embed ``x`` and optionally resize the last dimension.

        Args:
            x: Tensor whose last dimension has size 1.
            target_dim: When given and different from ``output_size``, the
                result is zero-padded or truncated to this width.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``output_size`` (or ``target_dim``).
        """
        features = self.net(x)
        outputs = self.output_proj(features)
        if target_dim is not None and target_dim != self.output_size:
            return self._adjust_dimension(outputs, target_dim)
        return outputs

    def _adjust_dimension(self, embeddings, target_dim):
        """Zero-pad or truncate the last dimension of ``embeddings`` to ``target_dim``."""
        current_dim = embeddings.shape[-1]
        if target_dim > current_dim:
            pad_size = target_dim - current_dim
            padding = torch.zeros(
                (*embeddings.shape[:-1], pad_size),
                device=embeddings.device,
                dtype=embeddings.dtype,
            )
            return torch.cat([embeddings, padding], dim=-1)
        if target_dim < current_dim:
            return embeddings[..., :target_dim]
        return embeddings

    def save_pretrained(self, save_directory):
        """Write the config and the weights into ``save_directory``.

        NOTE: despite its name, ``config.json`` is written with ``torch.save``
        (a pickle-based format), not as JSON text. ``from_pretrained`` reads it
        back the same way, so the format is kept for compatibility.
        """
        os.makedirs(save_directory, exist_ok=True)
        config = {"hidden_size": self.hidden_size, "output_size": self.output_size}
        torch.save(config, os.path.join(save_directory, "config.json"))
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    # Alternate constructor: it must be a classmethod, because callers invoke
    # it on the class itself (``FidelityMLP.from_pretrained(path)``).
    @classmethod
    def from_pretrained(cls, pretrained_model_path):
        """Rebuild a model from a directory produced by ``save_pretrained``.

        Args:
            pretrained_model_path: Directory containing ``config.json`` and
                ``pytorch_model.bin``.

        Returns:
            A new instance with the stored weights loaded (on CPU).

        Raises:
            FileNotFoundError: If either file is missing.
            TypeError: If the stored config is not a dict.
            Exception: Any error raised while loading is printed and re-raised.
        """
        config_file = os.path.join(pretrained_model_path, "config.json")
        model_file = os.path.join(pretrained_model_path, "pytorch_model.bin")
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"Config file not found at {config_file}")
        if not os.path.exists(model_file):
            raise FileNotFoundError(f"Model file not found at {model_file}")

        # NOTE(review): torch.load unpickles its input; only point this at
        # checkpoints from a trusted source.
        try:
            config = torch.load(config_file, map_location=torch.device('cpu'))
            if not isinstance(config, dict):
                raise TypeError(f"Expected config dict, got {type(config)}")
        except Exception as e:
            print(f"Error loading config {config_file}: {e}")
            raise

        model = cls(
            hidden_size=config["hidden_size"],
            output_size=config.get("output_size", config["hidden_size"]),
        )
        try:
            state_dict = torch.load(model_file, map_location=torch.device('cpu'))
            model.load_state_dict(state_dict)
            print(f"Successfully loaded FidelityMLP state dict from {model_file}")
        except Exception as e:
            print(f"Error loading state dict {model_file}: {e}")
            raise
        return model
# --- Global Variables ---
# Hugging Face Hub repository the model is loaded from.
model_id = "Scaryplasmon96/DoodlePixV1"

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Holds the loaded pipeline; stays None until load_pipeline() succeeds.
pipeline = None
# --- Model Loading Function ---
def load_pipeline():
    """Load the DoodlePix pipeline into the module-level ``pipeline`` global.

    The call is a no-op when the pipeline is already loaded. The optional
    Fidelity MLP is fetched from the ``fidelity_mlp/`` folder of the model
    repository; if it is missing or fails to load, the pipeline is still
    created without it.

    Returns:
        True once the pipeline is available.

    Raises:
        gr.Error: If the scheduler or the pipeline cannot be loaded. The
            ``pipeline`` global is reset to None in that case.
    """
    global pipeline
    if pipeline is not None:
        return True

    print(f"Loading model {model_id} onto {device}...")
    # Half precision is only used on the GPU; many CPU kernels are missing or
    # very slow for float16, so fall back to float32 there.
    dtype = torch.float16 if device == "cuda" else torch.float32
    try:
        hf_cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
        local_model_path = model_id  # Let diffusers find/download

        # Load Fidelity MLP if possible (best effort: failures are logged only).
        fidelity_mlp_instance = None
        try:
            from huggingface_hub import snapshot_download, hf_hub_download

            # Attempt to download config first to check existence
            hf_hub_download(repo_id=model_id, filename="fidelity_mlp/config.json", cache_dir=hf_cache_dir)
            # If config exists, download the whole subfolder.
            # (local_dir_use_symlinks only matters together with local_dir, so it is not passed.)
            fidelity_mlp_path = snapshot_download(
                repo_id=model_id, allow_patterns="fidelity_mlp/*", cache_dir=hf_cache_dir
            )
            fidelity_mlp_instance = FidelityMLP.from_pretrained(os.path.join(fidelity_mlp_path, "fidelity_mlp"))
            fidelity_mlp_instance = fidelity_mlp_instance.to(device=device, dtype=dtype)
            print("Fidelity MLP loaded successfully.")
        except Exception as e:
            print(f"Fidelity MLP not found or failed to load for {model_id}: {e}. Proceeding without MLP.")
            fidelity_mlp_instance = None

        scheduler = EulerAncestralDiscreteScheduler.from_pretrained(local_model_path, subfolder="scheduler")
        pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            local_model_path, torch_dtype=dtype, scheduler=scheduler, safety_checker=None
        ).to(device)

        if fidelity_mlp_instance is not None:
            pipeline.fidelity_mlp = fidelity_mlp_instance
            print("Attached Fidelity MLP to pipeline.")

        # Optimizations
        if device == "cuda" and hasattr(pipeline, "enable_xformers_memory_efficient_attention"):
            try:
                pipeline.enable_xformers_memory_efficient_attention()
                print("Enabled xformers.")
            except Exception:
                # xformers is optional; fall back to the built-in memory saver.
                print("Could not enable xformers. Using attention slicing.")
                pipeline.enable_attention_slicing()
        else:
            pipeline.enable_attention_slicing()
            print("Enabled attention slicing.")

        print("Pipeline loaded successfully.")
        return True
    except Exception as e:
        print(f"Error loading pipeline: {e}")
        traceback.print_exc()
        pipeline = None
        raise gr.Error(f"Failed to load model: {e}") from e
# --- Image Generation Function ---
def _flatten_on_white(image):
    """Return ``image`` as RGB, compositing any transparency over white."""
    has_alpha = image.mode in ("RGBA", "LA") or (image.mode == "P" and "transparency" in image.info)
    if not has_alpha:
        return image.convert("RGB")
    rgba = image.convert("RGBA")
    canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(canvas, rgba).convert("RGB")


def generate_image(drawing_input, prompt, fidelity_slider, steps, guidance, image_guidance, seed_val):
    """Turn a sketch plus a text prompt into a generated image.

    Args:
        drawing_input: Dict that must hold a PIL image under the
            ``"composite"`` key.
        prompt: User text; it is prefixed with ``f<fidelity>, `` and
            suffixed with `` background.`` unless it already ends with
            ``background.``.
        fidelity_slider: Fidelity level, truncated to an int for the prompt prefix.
        steps: Number of inference steps (converted to int).
        guidance: Text guidance scale (converted to float).
        image_guidance: Image guidance scale (converted to float).
        seed_val: Seed for the random generator (converted to int).

    Returns:
        ``(image, status_message)``; ``image`` is None when the input is
        malformed or generation fails, and the message explains why.

    Raises:
        gr.Error: Propagated from ``load_pipeline`` when the model cannot be loaded.
    """
    global pipeline
    if pipeline is None:
        if not load_pipeline():
            return None, "Model not loaded. Check logs."

    # --- Input Processing ---
    print(f"DEBUG: Received drawing_input type: {type(drawing_input)}")
    if isinstance(drawing_input, dict):
        print(f"DEBUG: Received drawing_input keys: {drawing_input.keys()}")

    composite = drawing_input.get("composite") if isinstance(drawing_input, dict) else None
    if not isinstance(composite, Image.Image):
        err_msg = "Drawing input format unexpected. Expected dict with PIL Image under 'composite' key."
        print(f"ERROR: {err_msg} Input: {drawing_input}")
        return None, err_msg

    # A plain convert("RGB") just drops the alpha channel, which would turn a
    # transparent canvas into whatever colour sits under the transparent
    # pixels. Flatten onto white so the sketch is always black-on-white.
    input_image_pil = _flatten_on_white(composite)
    print("DEBUG: Using PIL Image from 'composite' key.")
    # --- End Input Processing ---

    try:
        # Invert the image: White bg -> Black bg, Black lines -> White lines
        input_image_inverted = ImageOps.invert(input_image_pil)

        # Ensure image is 512x512
        if input_image_inverted.size != (512, 512):
            print(f"Resizing input image from {input_image_inverted.size} to (512, 512)")
            input_image_inverted = input_image_inverted.resize((512, 512), Image.Resampling.LANCZOS)

        # Prompt Construction
        final_prompt = f"f{int(fidelity_slider)}, {prompt}"
        if not final_prompt.endswith("background."):
            final_prompt += " background."
        negative_prompt = "artifacts, blur, jpg, uncanny, deformed, glow, shadow, text, words, letters, signature, watermark"

        # Generation
        print(f"Generating with: Prompt='{final_prompt[:100]}...', Fidelity={int(fidelity_slider)}, Steps={steps}, Guidance={guidance}, ImageGuidance={image_guidance}, Seed={seed_val}")
        seed_val = int(seed_val)
        generator = torch.Generator(device=device).manual_seed(seed_val)
        with torch.no_grad():
            output = pipeline(
                prompt=final_prompt, negative_prompt=negative_prompt, image=input_image_inverted,
                num_inference_steps=int(steps), guidance_scale=float(guidance),
                image_guidance_scale=float(image_guidance), generator=generator,
            ).images[0]
        print("Generation complete.")
        return output, "Generation Complete"
    except Exception as e:
        # Report the failure in the status box instead of crashing the handler.
        print(f"Error during generation: {e}")
        traceback.print_exc()
        return None, f"Error during generation: {str(e)}"
# --- Gradio Interface ---
def _initial_status():
    """Describe whether the pipeline global is populated when the page opens."""
    if pipeline is not None:
        return "Model loaded. Ready to generate."
    return "Model not loaded yet. It will be loaded on the first generation."


with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", secondary_hue="blue")) as demo:
    gr.Markdown("# DoodlePix Gradio App")
    gr.Markdown(f"Using model: `{model_id}`.")
    status_output = gr.Textbox(label="Status", interactive=False, value="App loading...")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 1. Draw Something (Black on White)")
            # Keep type="pil": generate_image expects a dict holding a PIL
            # image under the 'composite' key.
            drawing = gr.Sketchpad(
                label="Drawing Canvas",
                type="pil",
                height=512, width=512,
                brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=5),
                show_label=True,
            )
            prompt_input = gr.Textbox(label="2. Enter Prompt", placeholder="Describe the image you want...")
            fidelity = gr.Slider(0, 9, step=1, value=4, label="Fidelity (0=Creative, 9=Faithful)")
            num_steps = gr.Slider(10, 50, step=1, value=25, label="Inference Steps")
            guidance_scale = gr.Slider(1.0, 15.0, step=0.5, value=7.5, label="Guidance Scale (CFG)")
            image_guidance_scale = gr.Slider(0.5, 5.0, step=0.1, value=1.5, label="Image Guidance Scale")
            seed = gr.Number(label="Seed", value=42, precision=0)
            generate_button = gr.Button("🚀 Generate Image!", variant="primary")
        with gr.Column(scale=1):
            gr.Markdown("## 3. Generated Image")
            output_image = gr.Image(label="Result", type="pil", height=512, width=512, show_label=True)

    generate_button.click(
        fn=generate_image,
        inputs=[drawing, prompt_input, fidelity, num_steps, guidance_scale, image_guidance_scale, seed],
        outputs=[output_image, status_output],
    )

    # Replace the static "App loading..." placeholder with the real model
    # state every time the page is opened.
    demo.load(fn=_initial_status, inputs=None, outputs=status_output)
# --- Launch App ---
if __name__ == "__main__":
    # Try to load the model before serving; a failure here is only logged,
    # the app still starts.
    initial_status = "App loading..."
    print("Attempting to pre-load pipeline...")
    try:
        preloaded = load_pipeline()
    except Exception as exc:
        print(f"Pre-loading failed: {exc}")
        initial_status = f"Model pre-loading failed: {exc}. Will retry on first generation."
    else:
        if preloaded:
            initial_status = "Model pre-loaded successfully."
        else:
            initial_status = "Model pre-loading failed. Will retry on first generation."
    print(f"Pre-loading status: {initial_status}")
    demo.launch()