DaniKaEp commited on
Commit
3dc57ed
·
verified ·
1 Parent(s): 6d53984

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -19,8 +19,7 @@ class_weights = torch.load("class_weights_period.pt")
19
  checkpoint_path = 'epoch=22-step=213621.ckpt'
20
  vae_model = VAE.load_from_checkpoint(checkpoint_path,image_channels=1,z_dim=12, lr =0.0001, use_classification_loss=True, num_classes=num_classes,
21
  loss_type="weighted", class_weights=class_weights, device = device)
22
- model.load_state_dict(torch.load('epoch=22-step=213621.ckpt'))
23
- model.eval()
24
 
25
  # Load your dataframe encoding
26
  df_encodings = pd.read_csv('df_vae_encoding_April16_all.csv')
@@ -37,10 +36,10 @@ def generate_image(period1, period2, interpolation_value):
37
 
38
  i = interpolation_value
39
  new_tablet = (1-i) * image1 + i * image2
40
- new_tab_long = model.fc3(new_tablet).unsqueeze(0)
41
 
42
  with torch.no_grad():
43
- generated_image = model.decoder(new_tab_long)
44
  generated_image = generated_image[0][0].detach().cpu().numpy()
45
  generated_image = (generated_image * 255).astype(np.uint8)
46
  pil_img = PILImage.fromarray(generated_image)
 
19
  checkpoint_path = 'epoch=22-step=213621.ckpt'
20
  vae_model = VAE.load_from_checkpoint(checkpoint_path,image_channels=1,z_dim=12, lr =0.0001, use_classification_loss=True, num_classes=num_classes,
21
  loss_type="weighted", class_weights=class_weights, device = device)
22
+ vae_model.eval()
 
23
 
24
  # Load your dataframe encoding
25
  df_encodings = pd.read_csv('df_vae_encoding_April16_all.csv')
 
36
 
37
  i = interpolation_value
38
  new_tablet = (1-i) * image1 + i * image2
39
+ new_tab_long = vae_model.fc3(new_tablet).unsqueeze(0)
40
 
41
  with torch.no_grad():
42
+ generated_image = vae_model.decoder(new_tab_long)
43
  generated_image = generated_image[0][0].detach().cpu().numpy()
44
  generated_image = (generated_image * 255).astype(np.uint8)
45
  pil_img = PILImage.fromarray(generated_image)