# Hugging Face Spaces build-log excerpt (paste residue, kept for context):
#   Runtime error
#   Runtime error
# ... (Imports and STYLE_OPTIONS/STYLE_EMBEDDINGS are the same) ...
# --- 1. CONFIGURATION AND MODEL PLACEHOLDERS ---
# ... (CustomTextEncoder class is the same) ...
class GANGenerator(torch.nn.Module):
    """
    Conditional GAN Generator Placeholder with robust device handling.

    Conditions on three embeddings of size ``embed_dim`` (positive prompt,
    negative prompt, style) concatenated with a fresh latent noise vector of
    size ``latent_dim``, and emits an RGB image tensor in [-1, 1].
    """

    def __init__(self, latent_dim: int = 100, embed_dim: int = 768,
                 img_size: int = 256) -> None:
        """
        Args:
            latent_dim: Size of the random noise vector z.
            embed_dim: Size of each of the three conditioning embeddings.
            img_size: Side length of the square RGB output image.
        """
        super().__init__()
        input_dim = latent_dim + embed_dim * 3
        # BUG FIX: the layer previously produced 3*128*128 values while
        # forward() reshaped to (batch, 3, 256, 256), so every forward pass
        # raised "RuntimeError: shape invalid for input of size ...".
        # Both sizes are now derived from the single img_size parameter.
        self.img_size = img_size
        self.output_pixels = 3 * img_size * img_size  # 3 color channels
        self.fc = torch.nn.Linear(input_dim, self.output_pixels)
        self.latent_dim = latent_dim

    def forward(self, c_pos: torch.Tensor, c_neg: torch.Tensor,
                s_embed: torch.Tensor) -> torch.Tensor:
        """
        Args:
            c_pos: Positive-prompt embedding, shape (batch, embed_dim).
            c_neg: Negative-prompt embedding, shape (batch, embed_dim).
            s_embed: Style embedding, shape (batch, embed_dim).

        Returns:
            float32 image tensor of shape (batch, 3, img_size, img_size)
            with values in [-1, 1] (tanh output).
        """
        batch_size = c_pos.shape[0]
        # Get the device from an input tensor (e.g. c_pos) so z is created
        # on the same device as the conditioning inputs.
        device = c_pos.device
        # 1. Explicitly create the noise vector z on the correct device.
        z = torch.randn(batch_size, self.latent_dim, device=device,
                        dtype=torch.float32)
        # 2. Concatenate all conditioning inputs along the feature axis.
        combined_conditioning = torch.cat([z, c_pos, c_neg, s_embed], dim=1)
        # 3. Feedforward pass (placeholder for a real generator stack).
        x = self.fc(combined_conditioning)
        # 4. Reshape to image layout and squash into [-1, 1].
        image_tensor = x.view(batch_size, 3, self.img_size,
                              self.img_size).tanh()
        return image_tensor.to(torch.float32)
# --- 2. INITIALIZATION (Runs once on the Host/CPU) ---
# Pick CUDA when available; on ZeroGPU the GPU may only be attached inside
# the decorated inference function, so this can legitimately resolve to "cpu".
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
try:
    # CustomTextEncoder is defined earlier in this file (not shown here);
    # presumably it loads a text model onto DEVICE — confirm its contract.
    text_encoder = CustomTextEncoder(device=DEVICE)
    # Generator is moved to DEVICE and switched to eval mode for inference.
    generator = GANGenerator().to(DEVICE).eval()
    print(f"Models initialized on {DEVICE}")
except Exception as e:
    # Best-effort startup: on any failure fall back to None so the UI can
    # still come up; generate_image() checks for None and returns a blank.
    print(f"Warning: Model initialization failed. Running with dummy data. Error: {e}")
    text_encoder = None
    generator = None
# --- 3. ZERO GPU / GRADIO INTERFACE FUNCTION ---
def generate_image(positive_prompt: str, negative_prompt: str, style: str) -> Image.Image:
    """The main inference function, decorated for ZeroGPU.

    Args:
        positive_prompt: Text describing what the image should contain.
        negative_prompt: Text describing what the image should avoid.
        style: Key into STYLE_EMBEDDINGS; unknown keys fall back to
            "Photorealistic".

    Returns:
        A PIL image: the generated picture on success, a black placeholder
        when the models failed to initialize, or a solid red image when the
        forward pass raises a RuntimeError.
    """
    if generator is None or text_encoder is None:
        # Model initialization failed at import time; return a black image.
        return Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))

    # 1. Encode both prompts. The encoder's output device/dtype is defined
    # elsewhere; everything is normalized to float32 below.
    c_pos = text_encoder.encode(positive_prompt)
    c_neg = text_encoder.encode(negative_prompt)

    # Move the style embedding to DEVICE and add a batch dimension.
    s_embed = STYLE_EMBEDDINGS.get(style, STYLE_EMBEDDINGS["Photorealistic"]).to(DEVICE).unsqueeze(0)

    # Cast all conditioning tensors to float32 (standard for most GANs).
    c_pos = c_pos.to(torch.float32)
    c_neg = c_neg.to(torch.float32)
    s_embed = s_embed.to(torch.float32)

    # --- DEBUGGING STEP: Check Shapes and Devices before generation ---
    print("\n--- DEBUG INFO BEFORE GENERATION ---")
    print(f"Generator device: {next(generator.parameters()).device}")
    print(f"c_pos shape: {c_pos.shape}, device: {c_pos.device}")
    print(f"c_neg shape: {c_neg.shape}, device: {c_neg.device}")
    print(f"s_embed shape: {s_embed.shape}, device: {s_embed.device}")
    print("------------------------------------\n")
    # -----------------------------------------------------------------

    try:
        # 2. Generate Image (forward pass; no gradients needed at inference).
        with torch.no_grad():
            image_tensor = generator(c_pos, c_neg, s_embed)

        # 3. Post-process: map tanh output [-1, 1] -> [0, 255] uint8.
        image_tensor = (image_tensor * 0.5 + 0.5) * 255.0
        image_tensor = image_tensor.clamp(0, 255).byte()
        # Convert from C,H,W to H,W,C for numpy/PIL. NOTE(review): squeeze(0)
        # assumes a batch size of exactly 1 — confirm callers never batch.
        image_numpy = image_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
        return Image.fromarray(image_numpy)
    except RuntimeError as e:
        # Catch and report the specific runtime error in the logs.
        print(f"\nFATAL RUNTIME ERROR DURING GENERATION: {e}\n")
        if "out of memory" in str(e).lower():
            # If it's OOM, suggest resolution reduction.
            error_message = "CUDA Out of Memory Error: The model is too large for the allocated ZeroGPU memory. Try reducing the output resolution (e.g., from 256x256 to 128x128) in the GANGenerator class."
        else:
            # Assume device/type mismatch for other RuntimeError cases.
            error_message = f"Runtime Error: Tensors or model parameters are likely on different devices (CPU/CUDA) or have mismatched data types (float32/float64). See logs for full traceback. Error: {e}"
        # BUG FIX: error_message was built but never surfaced anywhere; log it
        # so the diagnosis actually reaches the Space logs.
        print(error_message)
        # Return a solid red image so the failure is visible to the user.
        error_img = np.full((256, 256, 3), [255, 0, 0], dtype=np.uint8)
        return Image.fromarray(error_img)
# --- 4. GRADIO APP DEFINITION (Same as before) ---
# ... (The rest of the Gradio Blocks definition remains the same) ...
if __name__ == "__main__":
    # NOTE(review): in current Gradio versions `theme` is a parameter of
    # gr.Blocks(...) / gr.Interface(...), not of launch(); passing it here
    # may raise a TypeError or be ignored — verify against the installed
    # gradio version and the Blocks construction elided above.
    demo.launch(theme=gr.themes.Soft())