# Source: Hugging Face Space "app.py" uploaded by user xen87348, commit c3f36d5 (verified).
# ... (Imports and STYLE_OPTIONS/STYLE_EMBEDDINGS are the same) ...
# --- 1. CONFIGURATION AND MODEL PLACEHOLDERS ---
# ... (CustomTextEncoder class is the same) ...
class GANGenerator(torch.nn.Module):
    """
    Conditional GAN Generator Placeholder with robust device handling.

    Conditions on three embeddings (positive prompt, negative prompt, style),
    concatenates them with a freshly sampled latent noise vector, and maps the
    result through a single linear layer to an RGB image in [-1, 1].

    Args:
        latent_dim: Size of the random noise vector z sampled per forward pass.
        embed_dim:  Size of each of the three conditioning embeddings.
        img_size:   Output image side length (square); default 256 matches the
                    post-processing in the rest of this app.
    """

    def __init__(self, latent_dim: int = 100, embed_dim: int = 768, img_size: int = 256):
        super().__init__()
        # z + positive + negative + style embeddings, concatenated along dim 1.
        input_dim = latent_dim + embed_dim * 3
        # FIX: the original sized this layer for 3 * 128 * 128 outputs while
        # forward() reshaped to (batch, 3, 256, 256) — the view() would fail.
        # Output: 3 color channels * img_size * img_size pixels.
        self.img_size = img_size
        self.output_pixels = 3 * img_size * img_size
        self.fc = torch.nn.Linear(input_dim, self.output_pixels)
        self.latent_dim = latent_dim

    def forward(self, c_pos: torch.Tensor, c_neg: torch.Tensor, s_embed: torch.Tensor) -> torch.Tensor:
        """Generate a batch of images conditioned on the three embeddings.

        Args:
            c_pos:   (batch, embed_dim) positive-prompt embedding.
            c_neg:   (batch, embed_dim) negative-prompt embedding.
            s_embed: (batch, embed_dim) style embedding.

        Returns:
            float32 tensor of shape (batch, 3, img_size, img_size) in [-1, 1].
        """
        batch_size = c_pos.shape[0]
        # Get the device from an input tensor to ensure consistency.
        device = c_pos.device
        # Explicitly create the noise vector z on the same device as the inputs.
        z = torch.randn(batch_size, self.latent_dim, device=device, dtype=torch.float32)
        # Concatenate noise with all conditioning inputs along the feature dim.
        combined_conditioning = torch.cat([z, c_pos, c_neg, s_embed], dim=1)
        # Feedforward pass (placeholder for a real generator network).
        x = self.fc(combined_conditioning)
        # Reshape to image layout and squash into [-1, 1] with tanh.
        image_tensor = x.view(batch_size, 3, self.img_size, self.img_size).tanh()
        return image_tensor.to(torch.float32)
# --- 2. INITIALIZATION (Runs once on the Host/CPU) ---
# Prefer CUDA when available; fall back to CPU so the app still starts locally.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
try:
    # NOTE(review): CustomTextEncoder is defined in an elided section above.
    text_encoder = CustomTextEncoder(device=DEVICE)
    # eval() switches off training-only behavior (dropout/batch-norm updates).
    generator = GANGenerator().to(DEVICE).eval()
    print(f"Models initialized on {DEVICE}")
except Exception as e:
    # Deliberate best-effort: any init failure leaves both models as None so
    # generate_image() can return a placeholder image instead of crashing the app.
    print(f"Warning: Model initialization failed. Running with dummy data. Error: {e}")
    text_encoder = None
    generator = None
# --- 3. ZERO GPU / GRADIO INTERFACE FUNCTION ---
@spaces.GPU(duration=30)
def generate_image(positive_prompt: str, negative_prompt: str, style: str) -> Image.Image:
    """The main inference function, decorated for ZeroGPU.

    Encodes the prompts and style, runs the GAN generator, and converts the
    output tensor to a PIL image.

    Args:
        positive_prompt: Text describing what the image should contain.
        negative_prompt: Text describing what the image should avoid.
        style: Key into STYLE_EMBEDDINGS; unknown keys fall back to
            "Photorealistic".

    Returns:
        A 256x256 PIL image; a black image when the models failed to
        initialize, or a solid red image when generation raises a
        RuntimeError.
    """
    if generator is None or text_encoder is None:
        # Models never initialized (see init section): black placeholder.
        return Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))

    # 1. Encode text inputs into conditioning embeddings.
    c_pos = text_encoder.encode(positive_prompt)
    c_neg = text_encoder.encode(negative_prompt)
    # Move the style embedding to DEVICE and add a batch dimension.
    s_embed = STYLE_EMBEDDINGS.get(style, STYLE_EMBEDDINGS["Photorealistic"]).to(DEVICE).unsqueeze(0)

    # Cast all inputs to float32 to match the generator's parameter dtype.
    c_pos = c_pos.to(torch.float32)
    c_neg = c_neg.to(torch.float32)
    s_embed = s_embed.to(torch.float32)

    # --- DEBUGGING STEP: Check Shapes and Devices before generation ---
    print("\n--- DEBUG INFO BEFORE GENERATION ---")
    print(f"Generator device: {next(generator.parameters()).device}")
    print(f"c_pos shape: {c_pos.shape}, device: {c_pos.device}")
    print(f"c_neg shape: {c_neg.shape}, device: {c_neg.device}")
    print(f"s_embed shape: {s_embed.shape}, device: {s_embed.device}")
    print("------------------------------------\n")
    # -----------------------------------------------------------------
    try:
        # 2. Generate Image (Forward Pass) — no gradients needed at inference.
        with torch.no_grad():
            image_tensor = generator(c_pos, c_neg, s_embed)
        # 3. Post-process: map tanh output from [-1, 1] to [0, 255] bytes.
        image_tensor = (image_tensor * 0.5 + 0.5) * 255.0
        image_tensor = image_tensor.clamp(0, 255).byte()
        # Convert from C H W to H W C (for numpy/PIL)
        image_numpy = image_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
        return Image.fromarray(image_numpy)
    except RuntimeError as e:
        # Catch and report the specific runtime error in the logs
        print(f"\nFATAL RUNTIME ERROR DURING GENERATION: {e}\n")
        if "out of memory" in str(e).lower():
            # If it's OOM, suggest resolution reduction
            error_message = "CUDA Out of Memory Error: The model is too large for the allocated ZeroGPU memory. Try reducing the output resolution (e.g., from 256x256 to 128x128) in the GANGenerator class."
        else:
            # Assume device/type mismatch for other RuntimeError cases
            error_message = f"Runtime Error: Tensors or model parameters are likely on different devices (CPU/CUDA) or have mismatched data types (float32/float64). See logs for full traceback. Error: {e}"
        # FIX: error_message was previously built but never emitted anywhere —
        # log it so the diagnostic actually reaches the Space logs.
        print(error_message)
        # Return a red error image so the UI signals failure visibly.
        error_img = np.full((256, 256, 3), [255, 0, 0], dtype=np.uint8)
        return Image.fromarray(error_img)
# --- 4. GRADIO APP DEFINITION (Same as before) ---
# ... (The rest of the Gradio Blocks definition remains the same) ...
if __name__ == "__main__":
    # `demo` is the gr.Blocks app defined in the elided Gradio section above.
    # NOTE(review): passing `theme=` to launch() is unusual — Gradio themes are
    # normally given to gr.Blocks(theme=...); confirm this is accepted by the
    # installed Gradio version.
    demo.launch(theme=gr.themes.Soft())