# ... (Imports and STYLE_OPTIONS/STYLE_EMBEDDINGS are the same) ...

# --- 1. CONFIGURATION AND MODEL PLACEHOLDERS ---
# ... (CustomTextEncoder class is the same) ...


class GANGenerator(torch.nn.Module):
    """Conditional GAN Generator Placeholder with robust device handling.

    Conditions on a positive-prompt embedding, a negative-prompt embedding
    and a style embedding (each of width ``embed_dim``), concatenates them
    with a fresh latent noise vector, and maps the result through a single
    linear layer to an RGB image whose pixels lie in [-1, 1] (tanh output).
    """

    def __init__(self, latent_dim: int = 100, embed_dim: int = 768,
                 img_size: int = 128) -> None:
        super().__init__()
        # Latent noise plus three conditioning embeddings, concatenated.
        input_dim = latent_dim + embed_dim * 3
        # BUG FIX: the original allocated 3 * 128 * 128 output features but
        # forward() reshaped to (3, 256, 256), so every forward pass raised
        # a "shape invalid for input of size" RuntimeError.  The spatial
        # size is now a single (backward-compatible) parameter used
        # consistently in both places.
        self.img_size = img_size
        self.output_pixels = 3 * img_size * img_size
        self.fc = torch.nn.Linear(input_dim, self.output_pixels)
        self.latent_dim = latent_dim

    def forward(self, c_pos: torch.Tensor, c_neg: torch.Tensor,
                s_embed: torch.Tensor) -> torch.Tensor:
        """Generate a batch of images from conditioning embeddings.

        Args:
            c_pos: positive-prompt embeddings, shape (batch, embed_dim).
            c_neg: negative-prompt embeddings, shape (batch, embed_dim).
            s_embed: style embeddings, shape (batch, embed_dim).

        Returns:
            float32 tensor of shape (batch, 3, img_size, img_size) in [-1, 1].
        """
        batch_size = c_pos.shape[0]
        # Get the device from an input tensor (e.g., c_pos) to ensure the
        # noise vector is created where the conditioning tensors live, so
        # torch.cat below never mixes CPU and CUDA tensors.
        device = c_pos.device
        z = torch.randn(batch_size, self.latent_dim,
                        device=device, dtype=torch.float32)
        # Concatenate noise with all conditioning inputs along features.
        combined_conditioning = torch.cat([z, c_pos, c_neg, s_embed], dim=1)
        # Feedforward pass (placeholder for a real generator network).
        x = self.fc(combined_conditioning)
        # Reshape to image layout; tanh bounds pixels to [-1, 1] and the
        # de-normalisation to [0, 255] happens in the UI post-processing.
        image_tensor = x.view(batch_size, 3, self.img_size, self.img_size).tanh()
        return image_tensor.to(torch.float32)


# --- 2. INITIALIZATION (Runs once on the Host/CPU) ---
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

try:
    text_encoder = CustomTextEncoder(device=DEVICE)
    generator = GANGenerator().to(DEVICE).eval()
    print(f"Models initialized on {DEVICE}")
except Exception as e:
    # Best-effort startup: keep the UI alive with dummy output instead of
    # crashing the whole Space when model setup fails.
    print(f"Warning: Model initialization failed. Running with dummy data. Error: {e}")
    text_encoder = None
    generator = None


# --- 3. ZERO GPU / GRADIO INTERFACE FUNCTION ---
@spaces.GPU(duration=30)
def generate_image(positive_prompt: str, negative_prompt: str, style: str) -> Image.Image:
    """The main inference function, decorated for ZeroGPU.

    Returns a generated PIL image; on failure returns a black placeholder
    (models missing) or a solid red image (runtime error during generation).
    """
    if generator is None or text_encoder is None:
        # Models never loaded at startup: return a black placeholder.
        return Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))

    # 1. Encode Inputs
    c_pos = text_encoder.encode(positive_prompt)
    c_neg = text_encoder.encode(negative_prompt)
    # Unknown styles fall back to "Photorealistic".  Move the (host-side)
    # style embedding to DEVICE so it matches the encoder outputs.
    s_embed = STYLE_EMBEDDINGS.get(style, STYLE_EMBEDDINGS["Photorealistic"]).to(DEVICE).unsqueeze(0)

    # Standardize dtypes: float32 is the convention for most GAN weights.
    c_pos = c_pos.to(torch.float32)
    c_neg = c_neg.to(torch.float32)
    s_embed = s_embed.to(torch.float32)

    # --- DEBUGGING STEP: Check Shapes and Devices before generation ---
    print("\n--- DEBUG INFO BEFORE GENERATION ---")
    print(f"Generator device: {next(generator.parameters()).device}")
    print(f"c_pos shape: {c_pos.shape}, device: {c_pos.device}")
    print(f"c_neg shape: {c_neg.shape}, device: {c_neg.device}")
    print(f"s_embed shape: {s_embed.shape}, device: {s_embed.device}")
    print("------------------------------------\n")
    # -----------------------------------------------------------------

    try:
        # 2. Generate Image (Forward Pass) — inference only, no gradients.
        with torch.no_grad():
            image_tensor = generator(c_pos, c_neg, s_embed)

        # 3. Post-process to PIL Image: map [-1, 1] -> [0, 255] uint8.
        image_tensor = (image_tensor * 0.5 + 0.5) * 255.0
        image_tensor = image_tensor.clamp(0, 255).byte()
        # Convert from C H W to H W C (for numpy/PIL)
        image_numpy = image_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
        return Image.fromarray(image_numpy)

    except RuntimeError as e:
        # Catch and report the specific runtime error in the logs
        print(f"\nFATAL RUNTIME ERROR DURING GENERATION: {e}\n")
        if "out of memory" in str(e).lower():
            # If it's OOM, suggest resolution reduction
            error_message = "CUDA Out of Memory Error: The model is too large for the allocated ZeroGPU memory. Try reducing the output resolution (e.g., from 256x256 to 128x128) in the GANGenerator class."
        else:
            # Assume device/type mismatch for other RuntimeError cases
            error_message = f"Runtime Error: Tensors or model parameters are likely on different devices (CPU/CUDA) or have mismatched data types (float32/float64). See logs for full traceback. Error: {e}"
        # BUG FIX: error_message was built but never used; log it so the
        # diagnostic guidance actually reaches the Space logs.
        print(error_message)
        # Return a red error image to the user
        error_img = np.full((256, 256, 3), [255, 0, 0], dtype=np.uint8)
        return Image.fromarray(error_img)


# --- 4. GRADIO APP DEFINITION (Same as before) ---
# ... (The rest of the Gradio Blocks definition remains the same) ...

if __name__ == "__main__":
    # NOTE(review): Blocks.launch() does not document a `theme` kwarg in
    # current Gradio — theme normally belongs to gr.Blocks(theme=...) in
    # the (elided) app definition above; confirm this call works as intended.
    demo.launch(theme=gr.themes.Soft())