xen87348 committed on
Commit
56b5d8a
·
verified ·
1 Parent(s): f15cb9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -116
app.py CHANGED
@@ -1,64 +1,15 @@
1
- import gradio as gr
2
- import torch
3
- import spaces
4
- from PIL import Image
5
- from transformers import CLIPTokenizer, CLIPTextModel
6
- import numpy as np
7
- import os
8
- from typing import Literal
9
 
10
  # --- 1. CONFIGURATION AND MODEL PLACEHOLDERS ---
11
 
12
# Style choices surfaced in the Gradio dropdown.
STYLE_OPTIONS: list[str] = ["Photorealistic", "Impressionist", "Oil Painting", "Pixel Art"]

# Placeholder style-conditioning vectors, 768-dim to match CLIP's text width.
# In a real system these would be learned or precomputed embeddings; here each
# style is just a constant vector at a distinct scale.
_STYLE_SCALES: dict[str, float] = {
    "Photorealistic": 0.0,
    "Impressionist": 0.2,
    "Oil Painting": 0.5,
    "Pixel Art": 0.8,
}
STYLE_EMBEDDINGS: dict[str, torch.Tensor] = {
    name: torch.full((768,), scale) for name, scale in _STYLE_SCALES.items()
}
22
-
23
class CustomTextEncoder:
    """Thin wrapper around CLIP's text tower that maps prompts to embeddings."""

    def __init__(self, device: str = "cuda"):
        # Pull the pretrained tokenizer / text-model pair and pin the model
        # to the requested device (the tokenizer is device-agnostic).
        self.device = device
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

    def encode(self, prompt: str) -> torch.Tensor:
        """Encode ``prompt`` into a single (1, 768) float32 embedding.

        An empty prompt yields an all-zero vector, which the caller uses as
        neutral/negative conditioning.
        """
        if not prompt:
            return torch.zeros(1, 768, device=self.device)

        tokenized = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            # Pooled output: one vector summarizing the whole prompt.
            pooled = self.text_model(**tokenized).pooler_output
        # Cast defensively so downstream consumers always see float32.
        return pooled.to(torch.float32)
49
 
50
  class GANGenerator(torch.nn.Module):
51
  """
52
- Conditional GAN Generator Placeholder.
53
- This architecture uses a simple linear layer to simulate generation based on:
54
- 1. Noise vector (z)
55
- 2. Positive Text Embedding (c_pos)
56
- 3. Negative Text Embedding (c_neg)
57
- 4. Style Embedding (s_embed)
58
  """
59
  def __init__(self, latent_dim: int = 100, embed_dim: int = 768):
60
  super().__init__()
61
- # Total input dimension = Noise (100) + Positive (768) + Negative (768) + Style (768)
62
  input_dim = latent_dim + embed_dim * 3
63
 
64
  # Output: 3 color channels * 256 * 256 image size
@@ -68,33 +19,29 @@ class GANGenerator(torch.nn.Module):
68
 
69
  def forward(self, c_pos: torch.Tensor, c_neg: torch.Tensor, s_embed: torch.Tensor) -> torch.Tensor:
70
  batch_size = c_pos.shape[0]
71
- device = c_pos.device # Ensure noise is on the correct device
 
72
 
73
- # 1. Generate noise vector
74
  z = torch.randn(batch_size, self.latent_dim, device=device, dtype=torch.float32)
75
 
76
  # 2. Concatenate all conditioning inputs
77
  combined_conditioning = torch.cat([z, c_pos, c_neg, s_embed], dim=1)
78
 
79
- # 3. Simple feedforward pass (Placeholder for complex GAN layers)
80
  x = self.fc(combined_conditioning)
81
 
82
- # 4. Reshape to image format (Batch, Channels, Height, Width) and normalize to [-1, 1]
83
  image_tensor = x.view(batch_size, 3, 256, 256).tanh()
84
  return image_tensor.to(torch.float32)
85
 
86
 
87
- # --- 2. INITIALIZATION (Runs once on the Host/CPU, moves to GPU if available) ---
88
  DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
89
 
90
  try:
91
- # Initialize models and move them to the target device
92
  text_encoder = CustomTextEncoder(device=DEVICE)
93
  generator = GANGenerator().to(DEVICE).eval()
94
-
95
- # 📝 NOTE: If you have pre-trained weights, load them here:
96
- # generator.load_state_dict(torch.load("your_pretrained_weights.pth"))
97
-
98
  print(f"Models initialized on {DEVICE}")
99
  except Exception as e:
100
  print(f"Warning: Model initialization failed. Running with dummy data. Error: {e}")
@@ -108,75 +55,62 @@ def generate_image(positive_prompt: str, negative_prompt: str, style: str) -> Im
108
  """The main inference function, decorated for ZeroGPU."""
109
 
110
  if generator is None or text_encoder is None:
111
- # Fallback for failed initialization
112
  return Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
113
 
114
  # 1. Encode Inputs
115
  c_pos = text_encoder.encode(positive_prompt)
116
  c_neg = text_encoder.encode(negative_prompt)
117
 
118
- # Get style embedding and move it to the correct device
119
  s_embed = STYLE_EMBEDDINGS.get(style, STYLE_EMBEDDINGS["Photorealistic"]).to(DEVICE).unsqueeze(0)
120
 
121
- # Ensure all inputs are float32
122
  c_pos = c_pos.to(torch.float32)
123
  c_neg = c_neg.to(torch.float32)
124
  s_embed = s_embed.to(torch.float32)
125
 
126
- # 2. Generate Image (Forward Pass)
127
- with torch.no_grad():
128
- image_tensor = generator(c_pos, c_neg, s_embed)
129
-
130
- # 3. Post-process to PIL Image
131
- # Convert from [-1, 1] range to [0, 255]
132
- image_tensor = (image_tensor * 0.5 + 0.5) * 255.0
133
- image_tensor = image_tensor.clamp(0, 255).byte()
134
-
135
- # Convert from C H W to H W C (for numpy/PIL)
136
- image_numpy = image_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
137
-
138
- return Image.fromarray(image_numpy)
139
-
140
- # --- 4. GRADIO APP DEFINITION (FIXED for v6.x) ---
141
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
 
144
- # Removed 'theme' keyword argument from gr.Blocks() to fix the TypeError
145
- with gr.Blocks(
146
- title="Custom Text-to-Image ZeroGPU GAN"
147
- ) as demo:
148
- gr.Markdown("## ✨ Conditional GAN with Negative and Style Prompting")
149
- gr.Markdown("Enter a **Positive Description** (what you want) and an **Anti-Description** (what you *don't* want).")
150
-
151
- with gr.Row():
152
- positive_prompt = gr.Textbox(
153
- label="1. Positive Description",
154
- value="A beautiful, vibrant oil painting of a lighthouse by the sea",
155
- lines=2
156
- )
157
- style_dropdown = gr.Dropdown(
158
- label="3. Choose Style",
159
- choices=STYLE_OPTIONS,
160
- value=STYLE_OPTIONS[1],
161
- scale=0.5
162
- )
163
 
164
- negative_prompt = gr.Textbox(
165
- label="2. Anti-Description (Negative Prompt)",
166
- value="ugly, noise, blurry, low resolution, watermark, text",
167
- lines=2
168
- )
169
-
170
- generate_button = gr.Button("🎨 Generate Image", variant="primary")
171
-
172
- output_image = gr.Image(label="Generated Image (256x256)", type="pil", height=256)
173
-
174
- generate_button.click(
175
- fn=generate_image,
176
- inputs=[positive_prompt, negative_prompt, style_dropdown],
177
- outputs=output_image
178
- )
179
 
180
  if __name__ == "__main__":
181
- # The 'theme' argument is correctly placed in the .launch() call for Gradio 6.x
182
  demo.launch(theme=gr.themes.Soft())
 
1
+ # ... (Imports and STYLE_OPTIONS/STYLE_EMBEDDINGS are the same) ...
 
 
 
 
 
 
 
2
 
3
  # --- 1. CONFIGURATION AND MODEL PLACEHOLDERS ---
4
 
5
+ # ... (CustomTextEncoder class is the same) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class GANGenerator(torch.nn.Module):
8
  """
9
+ Conditional GAN Generator Placeholder with robust device handling.
 
 
 
 
 
10
  """
11
  def __init__(self, latent_dim: int = 100, embed_dim: int = 768):
12
  super().__init__()
 
13
  input_dim = latent_dim + embed_dim * 3
14
 
15
  # Output: 3 color channels * 256 * 256 image size
 
19
 
20
  def forward(self, c_pos: torch.Tensor, c_neg: torch.Tensor, s_embed: torch.Tensor) -> torch.Tensor:
21
  batch_size = c_pos.shape[0]
22
+ # Get the device from an input tensor (e.g., c_pos) to ensure consistency
23
+ device = c_pos.device
24
 
25
+ # ✅ FIX 1: Explicitly create the noise vector Z on the correct device
26
  z = torch.randn(batch_size, self.latent_dim, device=device, dtype=torch.float32)
27
 
28
  # 2. Concatenate all conditioning inputs
29
  combined_conditioning = torch.cat([z, c_pos, c_neg, s_embed], dim=1)
30
 
31
+ # 3. Feedforward pass (Placeholder)
32
  x = self.fc(combined_conditioning)
33
 
34
+ # 4. Reshape and normalize
35
  image_tensor = x.view(batch_size, 3, 256, 256).tanh()
36
  return image_tensor.to(torch.float32)
37
 
38
 
39
+ # --- 2. INITIALIZATION (Runs once on the Host/CPU) ---
40
  DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
41
 
42
  try:
 
43
  text_encoder = CustomTextEncoder(device=DEVICE)
44
  generator = GANGenerator().to(DEVICE).eval()
 
 
 
 
45
  print(f"Models initialized on {DEVICE}")
46
  except Exception as e:
47
  print(f"Warning: Model initialization failed. Running with dummy data. Error: {e}")
 
55
  """The main inference function, decorated for ZeroGPU."""
56
 
57
  if generator is None or text_encoder is None:
 
58
  return Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
59
 
60
  # 1. Encode Inputs
61
  c_pos = text_encoder.encode(positive_prompt)
62
  c_neg = text_encoder.encode(negative_prompt)
63
 
64
+ # FIX 2: Ensure style embedding is moved to the correct DEVICE
65
  s_embed = STYLE_EMBEDDINGS.get(style, STYLE_EMBEDDINGS["Photorealistic"]).to(DEVICE).unsqueeze(0)
66
 
67
+ # FIX 3: Explicitly cast all input tensors to float32 (standard for most GANs)
68
  c_pos = c_pos.to(torch.float32)
69
  c_neg = c_neg.to(torch.float32)
70
  s_embed = s_embed.to(torch.float32)
71
 
72
+ # --- DEBUGGING STEP: Check Shapes and Devices before generation ---
73
+ print("\n--- DEBUG INFO BEFORE GENERATION ---")
74
+ print(f"Generator device: {next(generator.parameters()).device}")
75
+ print(f"c_pos shape: {c_pos.shape}, device: {c_pos.device}")
76
+ print(f"c_neg shape: {c_neg.shape}, device: {c_neg.device}")
77
+ print(f"s_embed shape: {s_embed.shape}, device: {s_embed.device}")
78
+ print("------------------------------------\n")
79
+ # -----------------------------------------------------------------
80
+
81
+ try:
82
+ # 2. Generate Image (Forward Pass)
83
+ with torch.no_grad():
84
+ image_tensor = generator(c_pos, c_neg, s_embed)
85
+
86
+ # 3. Post-process to PIL Image (conversion code remains the same)
87
+ image_tensor = (image_tensor * 0.5 + 0.5) * 255.0
88
+ image_tensor = image_tensor.clamp(0, 255).byte()
89
+
90
+ # Convert from C H W to H W C (for numpy/PIL)
91
+ image_numpy = image_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
92
+
93
+ return Image.fromarray(image_numpy)
94
+
95
+ except RuntimeError as e:
96
+ # Catch and report the specific runtime error in the logs
97
+ print(f"\nFATAL RUNTIME ERROR DURING GENERATION: {e}\n")
98
+
99
+ if "out of memory" in str(e).lower():
100
+ # If it's OOM, suggest resolution reduction
101
+ error_message = "CUDA Out of Memory Error: The model is too large for the allocated ZeroGPU memory. Try reducing the output resolution (e.g., from 256x256 to 128x128) in the GANGenerator class."
102
+ else:
103
+ # Assume device/type mismatch for other RuntimeError cases
104
+ error_message = f"Runtime Error: Tensors or model parameters are likely on different devices (CPU/CUDA) or have mismatched data types (float32/float64). See logs for full traceback. Error: {e}"
105
+
106
+ # Return a red error image to the user
107
+ error_img = np.full((256, 256, 3), [255, 0, 0], dtype=np.uint8)
108
+ return Image.fromarray(error_img)
109
 
110
 
111
+ # --- 4. GRADIO APP DEFINITION (Same as before) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
+ # ... (The rest of the Gradio Blocks definition remains the same) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  if __name__ == "__main__":
 
116
  demo.launch(theme=gr.themes.Soft())