phirni commited on
Commit
fbcd42e
·
verified ·
1 Parent(s): fab12b9

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +23 -21
inference.py CHANGED
@@ -2,32 +2,34 @@ import torch
2
  from torchvision import transforms
3
  from PIL import Image
4
  import numpy as np
5
- from model import ConvLSTMModel, BetaVAE
 
6
 
7
  # ===============================================================
8
- # Config
9
  # ===============================================================
10
  SEQUENCE_LENGTH = 10
11
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
 
13
  # ===============================================================
14
- # Load Models
15
  # ===============================================================
16
  def load_convlstm(path="convlstm_model.pth"):
17
- model = ConvLSTMModel()
18
  checkpoint = torch.load(path, map_location=DEVICE)
19
  model.load_state_dict(checkpoint)
20
  model.eval().to(DEVICE)
21
- print("✅ ConvLSTM model loaded.")
22
  return model
23
 
24
 
25
  def load_beta_vae(path="beta_vae_model.pth"):
26
- model = BetaVAE()
27
  checkpoint = torch.load(path, map_location=DEVICE)
28
  model.load_state_dict(checkpoint)
29
  model.eval().to(DEVICE)
30
- print("✅ β-VAE model loaded.")
31
  return model
32
 
33
 
@@ -35,7 +37,7 @@ def load_beta_vae(path="beta_vae_model.pth"):
35
  # Frame Pre/Post Processing
36
  # ===============================================================
37
  def preprocess_frame(frame: Image.Image):
38
- """Convert PIL image torch tensor (1,1,H,W) normalized to [0,1]."""
39
  transform = transforms.Compose([
40
  transforms.Grayscale(),
41
  transforms.Resize((64, 64)),
@@ -46,7 +48,7 @@ def preprocess_frame(frame: Image.Image):
46
 
47
 
48
  def postprocess_frame(tensor):
49
- """Convert torch tensor (1,1,H,W) → PIL image."""
50
  tensor = tensor.detach().cpu().clamp(0, 1)
51
  arr = tensor.squeeze().numpy() * 255
52
  arr = arr.astype(np.uint8)
@@ -54,33 +56,33 @@ def postprocess_frame(tensor):
54
 
55
 
56
  # ===============================================================
57
- # Inference Logic
58
  # ===============================================================
59
  @torch.no_grad()
60
- def predict_next_frame(model, sequence):
61
  """
 
62
  Args:
63
- model: ConvLSTMModel
64
  sequence: tensor (1, T, 1, H, W)
65
  Returns:
66
- PIL.Image
67
  """
68
- model.eval()
69
  sequence = sequence.to(DEVICE)
70
- next_frame = model(sequence) # (1,1,H,W)
71
  return postprocess_frame(next_frame)
72
 
73
 
74
  @torch.no_grad()
75
- def reconstruct_frame(model, frame):
76
  """
 
77
  Args:
78
- model: BetaVAE
79
- frame: tensor (1,1,H,W)
80
  Returns:
81
- PIL.Image
82
  """
83
- model.eval()
84
  frame = frame.to(DEVICE)
85
- recon, mu, logvar = model(frame)
86
  return postprocess_frame(recon)
 
2
  from torchvision import transforms
3
  from PIL import Image
4
  import numpy as np
5
+ from model import BetaVAE, ConvLSTM # your models
6
+ import torch.nn.functional as F
7
 
8
  # ===============================================================
9
+ # Configuration
10
  # ===============================================================
11
  SEQUENCE_LENGTH = 10
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
+
15
  # ===============================================================
16
+ # Model Loading
17
  # ===============================================================
18
  def load_convlstm(path="convlstm_model.pth"):
19
+ model = ConvLSTM(input_channels=1, hidden_channels=[64, 64, 64], output_channels=1)
20
  checkpoint = torch.load(path, map_location=DEVICE)
21
  model.load_state_dict(checkpoint)
22
  model.eval().to(DEVICE)
23
+ print("✅ ConvLSTM model loaded successfully.")
24
  return model
25
 
26
 
27
  def load_beta_vae(path="beta_vae_model.pth"):
28
+ model = BetaVAE(input_channels=1, latent_dim=64, beta=4.0)
29
  checkpoint = torch.load(path, map_location=DEVICE)
30
  model.load_state_dict(checkpoint)
31
  model.eval().to(DEVICE)
32
+ print("✅ β-VAE model loaded successfully.")
33
  return model
34
 
35
 
 
37
  # Frame Pre/Post Processing
38
  # ===============================================================
39
  def preprocess_frame(frame: Image.Image):
40
+ """Convert a PIL image to a normalized tensor (1, 1, 64, 64)."""
41
  transform = transforms.Compose([
42
  transforms.Grayscale(),
43
  transforms.Resize((64, 64)),
 
48
 
49
 
50
  def postprocess_frame(tensor):
51
+ """Convert tensor (1, 1, H, W) → PIL image."""
52
  tensor = tensor.detach().cpu().clamp(0, 1)
53
  arr = tensor.squeeze().numpy() * 255
54
  arr = arr.astype(np.uint8)
 
56
 
57
 
58
  # ===============================================================
59
+ # Inference Helpers
60
  # ===============================================================
61
  @torch.no_grad()
62
+ def predict_next_frame(convlstm_model, sequence):
63
  """
64
+ Predict the next frame using the ConvLSTM model.
65
  Args:
66
+ convlstm_model: trained ConvLSTM
67
  sequence: tensor (1, T, 1, H, W)
68
  Returns:
69
+ PIL.Image: predicted next frame
70
  """
 
71
  sequence = sequence.to(DEVICE)
72
+ next_frame = convlstm_model(sequence)
73
  return postprocess_frame(next_frame)
74
 
75
 
76
  @torch.no_grad()
77
+ def reconstruct_frame(beta_vae_model, frame):
78
  """
79
+ Reconstruct a single frame using the β-VAE.
80
  Args:
81
+ beta_vae_model: trained β-VAE
82
+ frame: tensor (1, 1, H, W)
83
  Returns:
84
+ PIL.Image: reconstructed frame
85
  """
 
86
  frame = frame.to(DEVICE)
87
+ recon, mu, logvar = beta_vae_model(frame)
88
  return postprocess_frame(recon)