phirni commited on
Commit
3d518d3
·
verified ·
1 Parent(s): 363e723

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -64,15 +64,24 @@ def predict_with_beta_vae(sequence, num_frames):
64
  def generate_predictions(frame_list, model_choice, mode_choice, num_frames):
65
  """
66
  Args:
67
- frame_list: uploaded images (list of PIL)
68
  model_choice: 'ConvLSTM' or 'β-VAE'
69
  mode_choice: 'Single Frame' or 'Multiple Frames'
70
  num_frames: number of consecutive frames to generate
71
  """
72
- if len(frame_list) < SEQUENCE_LENGTH:
73
  raise gr.Error(f"Please upload at least {SEQUENCE_LENGTH} sequential frames.")
74
 
75
- frames = frame_list[:SEQUENCE_LENGTH]
 
 
 
 
 
 
 
 
 
76
  processed = [preprocess_frame(f) for f in frames]
77
  sequence = torch.cat(processed, dim=0).unsqueeze(0) # (1, T, 1, H, W)
78
 
@@ -92,7 +101,6 @@ def generate_predictions(frame_list, model_choice, mode_choice, num_frames):
92
  description = """
93
  # 🕹️ Pong Frame Prediction
94
  Upload **10 sequential Pong frames** and select a model + prediction mode.
95
-
96
  - **ConvLSTM** → Learns temporal dynamics directly in pixel space
97
  - **β-VAE** → Predicts next frames via latent-space reconstruction
98
  """
@@ -105,9 +113,7 @@ demo = gr.Interface(
105
  gr.Radio(["Single Frame", "Multiple Frames"], label="Prediction Mode", value="Single Frame"),
106
  gr.Slider(1, 20, value=5, step=1, label="Number of Consecutive Frames (if Multiple Mode)"),
107
  ],
108
-
109
- outputs = gr.Gallery(label="Predicted Frames", elem_id="predicted-frames",show_label=True, columns=2), # columns replaces grid
110
-
111
  title="Pong Frame Predictor (ConvLSTM / β-VAE)",
112
  description=description,
113
  )
@@ -116,4 +122,4 @@ demo = gr.Interface(
116
  # Launch App
117
  # ===============================================================
118
  if __name__ == "__main__":
119
- demo.launch()
 
64
  def generate_predictions(frame_list, model_choice, mode_choice, num_frames):
65
  """
66
  Args:
67
+ frame_list: uploaded images (list of file paths as strings)
68
  model_choice: 'ConvLSTM' or 'β-VAE'
69
  mode_choice: 'Single Frame' or 'Multiple Frames'
70
  num_frames: number of consecutive frames to generate
71
  """
72
+ if frame_list is None or len(frame_list) < SEQUENCE_LENGTH:
73
  raise gr.Error(f"Please upload at least {SEQUENCE_LENGTH} sequential frames.")
74
 
75
+ # Convert file paths to PIL Images
76
+ frames = []
77
+ for file_path in frame_list[:SEQUENCE_LENGTH]:
78
+ # Handle both string paths and file objects
79
+ if isinstance(file_path, str):
80
+ img = Image.open(file_path)
81
+ else:
82
+ img = Image.open(file_path.name)
83
+ frames.append(img)
84
+
85
  processed = [preprocess_frame(f) for f in frames]
86
  sequence = torch.cat(processed, dim=0).unsqueeze(0) # (1, T, 1, H, W)
87
 
 
101
  description = """
102
  # 🕹️ Pong Frame Prediction
103
  Upload **10 sequential Pong frames** and select a model + prediction mode.
 
104
  - **ConvLSTM** → Learns temporal dynamics directly in pixel space
105
  - **β-VAE** → Predicts next frames via latent-space reconstruction
106
  """
 
113
  gr.Radio(["Single Frame", "Multiple Frames"], label="Prediction Mode", value="Single Frame"),
114
  gr.Slider(1, 20, value=5, step=1, label="Number of Consecutive Frames (if Multiple Mode)"),
115
  ],
116
+ outputs=gr.Gallery(label="Predicted Frames", elem_id="predicted-frames", show_label=True, columns=2),
 
 
117
  title="Pong Frame Predictor (ConvLSTM / β-VAE)",
118
  description=description,
119
  )
 
122
  # Launch App
123
  # ===============================================================
124
  if __name__ == "__main__":
125
+ demo.launch()