SaniaE commited on
Commit
1dc4910
·
verified ·
1 Parent(s): 0a7dd0a

added more debugging for model load

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -45,7 +45,12 @@ def load_model():
45
  checkpoint = torch.load(model_path, map_location=DEVICE)
46
 
47
  gen_model = Generator(z_dim=Z_DIM).to(DEVICE)
48
- gen_model.load_state_dict(checkpoint["gen_state_dict"], strict=False)
 
 
 
 
 
49
  gen_model.eval()
50
  print("SUCCESS: Petrol Pump GAN is live!")
51
  except Exception as e:
@@ -87,10 +92,7 @@ def generate_random():
87
  """Endpoint 1: Purely random generation for 'Discovery'."""
88
  if gen_model is None: raise HTTPException(status_code=503)
89
 
90
- with torch.inference_mode():
91
- seed = torch.seed()
92
- torch.manual_seed(seed)
93
-
94
  noise = torch.randn(1, Z_DIM, device=DEVICE)
95
  print("NOISE:", noise[0, :5])
96
  fake_img = gen_model(noise)
 
45
  checkpoint = torch.load(model_path, map_location=DEVICE)
46
 
47
  gen_model = Generator(z_dim=Z_DIM).to(DEVICE)
48
+ missing, unexpected = gen_model.load_state_dict(
49
+ checkpoint["gen_state_dict"], strict=False
50
+ )
51
+
52
+ print("Unexpected keys: ", unexpected)
53
+ print("Missing keys: ", missing)
54
  gen_model.eval()
55
  print("SUCCESS: Petrol Pump GAN is live!")
56
  except Exception as e:
 
92
  """Endpoint 1: Purely random generation for 'Discovery'."""
93
  if gen_model is None: raise HTTPException(status_code=503)
94
 
95
+ with torch.inference_mode():
 
 
 
96
  noise = torch.randn(1, Z_DIM, device=DEVICE)
97
  print("NOISE:", noise[0, :5])
98
  fake_img = gen_model(noise)