SaniaE commited on
Commit
7397034
·
verified ·
1 Parent(s): a510479

added more control for explore endpoint

Browse files
Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -26,8 +26,6 @@ Z_DIM = 100
26
  DEVICE = torch.device("cpu")
27
  REPO_ID = "SaniaE/GeoGen"
28
  FILENAME = "dcgans_model_checkpoint.pt"
29
-
30
- # Global model variable
31
  gen_model = None
32
 
33
  @app.on_event("startup")
@@ -49,7 +47,7 @@ def load_model():
49
  gen_model = Generator(z_dim=Z_DIM).to(DEVICE)
50
  gen_model.load_state_dict(checkpoint["gen_state_dict"], strict=False)
51
  gen_model.eval()
52
- print("SUCCESS: Petrol Pump GAN is live!")
53
  except Exception as e:
54
  print(f"Error loading model: {e}")
55
 
@@ -59,10 +57,8 @@ def postprocess_image(tensor):
59
  img_tensor = (tensor + 1) / 2
60
  img_tensor = img_tensor.clamp(0, 1)
61
 
62
- # Use make_grid to handle single or batch images
63
  grid = vutils.make_grid(img_tensor, padding=0, normalize=False)
64
 
65
- # Convert to HWC format for PIL
66
  ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
67
  return Image.fromarray(ndarr)
68
 
@@ -92,31 +88,32 @@ def generate_random():
92
  if gen_model is None: raise HTTPException(status_code=503)
93
 
94
  with torch.inference_mode():
 
 
95
  noise = torch.randn(1, Z_DIM, device=DEVICE)
96
  fake_img = gen_model(noise)
97
  return StreamingResponse(get_image_stream(fake_img), media_type="image/png")
98
 
99
 
100
  @app.get("/explore")
101
- def explore_latent(
102
- seed: int,
103
- x_shift: float = Query(0.0, ge=-5.0, le=5.0),
104
- y_shift: float = Query(0.0, ge=-5.0, le=5.0)
105
- ):
106
  """Endpoint 2: Controlled generation for 'Tuning'."""
107
  if gen_model is None: raise HTTPException(status_code=503)
108
 
109
  try:
110
  with torch.inference_mode():
111
- # Use the seed to recreate the base 'personality' of the image
112
  torch.manual_seed(seed)
113
  if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
114
-
115
  noise = torch.randn(1, Z_DIM, device=DEVICE)
116
 
117
- # Apply shifts to specific dimensions
118
- noise[0, 0] += x_shift
119
- noise[0, 1] += y_shift
 
 
 
 
120
 
121
  fake_img = gen_model(noise)
122
  return StreamingResponse(get_image_stream(fake_img), media_type="image/png")
 
26
  DEVICE = torch.device("cpu")
27
  REPO_ID = "SaniaE/GeoGen"
28
  FILENAME = "dcgans_model_checkpoint.pt"
 
 
29
  gen_model = None
30
 
31
  @app.on_event("startup")
 
47
  gen_model = Generator(z_dim=Z_DIM).to(DEVICE)
48
  gen_model.load_state_dict(checkpoint["gen_state_dict"], strict=False)
49
  gen_model.eval()
50
+ print("SUCCESS: Petrol Pump GAN is live!")
51
  except Exception as e:
52
  print(f"Error loading model: {e}")
53
 
 
57
  img_tensor = (tensor + 1) / 2
58
  img_tensor = img_tensor.clamp(0, 1)
59
 
 
60
  grid = vutils.make_grid(img_tensor, padding=0, normalize=False)
61
 
 
62
  ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
63
  return Image.fromarray(ndarr)
64
 
 
88
  if gen_model is None: raise HTTPException(status_code=503)
89
 
90
  with torch.inference_mode():
91
+ torch.seed()
92
+
93
  noise = torch.randn(1, Z_DIM, device=DEVICE)
94
  fake_img = gen_model(noise)
95
  return StreamingResponse(get_image_stream(fake_img), media_type="image/png")
96
 
97
 
98
  @app.get("/explore")
99
+ def explore_latent(seed: int, x_shift: float = Query(0.0, ge=-5.0, le=5.0), y_shift: float = Query(0.0, ge=-5.0, le=5.0)):
 
 
 
 
100
  """Endpoint 2: Controlled generation for 'Tuning'."""
101
  if gen_model is None: raise HTTPException(status_code=503)
102
 
103
  try:
104
  with torch.inference_mode():
 
105
  torch.manual_seed(seed)
106
  if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
107
+
108
  noise = torch.randn(1, Z_DIM, device=DEVICE)
109
 
110
+ # Structured control
111
+ noise[:, :10] += x_shift
112
+ noise[:, 10:20] += y_shift
113
+
114
+ # Random direction
115
+ direction = torch.randn_like(noise)
116
+ noise = noise + 0.3 * direction * (abs(x_shift) + abs(y_shift))
117
 
118
  fake_img = gen_model(noise)
119
  return StreamingResponse(get_image_stream(fake_img), media_type="image/png")