AlaaAbbas committed on
Commit
293bbb2
·
1 Parent(s): 391a195

Add application file

Browse files
Files changed (3) hide show
  1. app.py +332 -0
  2. model (2).pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import torch.nn as nn
6
+ import torchvision.transforms as transforms
7
+
8
+ F = torch.nn.functional
9
+
10
class ConvLSTMCell(nn.Module):
    """One step of a convolutional LSTM.

    Gates use sigmoid activations; the candidate and cell activations use
    ReLU (rather than the classic tanh), matching the trained checkpoint.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super(ConvLSTMCell, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size

        # One input-to-hidden and one hidden-to-hidden convolution per gate
        # (i: input, f: forget, g: candidate, o: output). Attribute names
        # (conv_ii, conv_hi, ...) must stay exactly as-is so the saved
        # state_dict loads; the creation order here matches the original.
        same_pad = self.kernel_size // 2
        for gate in ("i", "f", "g", "o"):
            setattr(
                self,
                "conv_i" + gate,
                nn.Conv2d(
                    self.input_channels,
                    self.hidden_channels,
                    self.kernel_size,
                    padding=same_pad,
                    bias=bias,
                ),
            )
            setattr(
                self,
                "conv_h" + gate,
                nn.Conv2d(
                    self.hidden_channels,
                    self.hidden_channels,
                    self.kernel_size,
                    padding=same_pad,
                    bias=bias,
                ),
            )

    def forward(self, x, hidden_state):
        """Advance the cell one time step.

        Args:
            x: Input of shape (batch, input_channels, H, W).
            hidden_state: Tuple (h_prev, c_prev), each of shape
                (batch, hidden_channels, H, W).

        Returns:
            Tuple (h, c): the next hidden and cell states.
        """
        h_prev, c_prev = hidden_state

        input_gate = torch.sigmoid(self.conv_ii(x) + self.conv_hi(h_prev))
        forget_gate = torch.sigmoid(self.conv_if(x) + self.conv_hf(h_prev))
        candidate = F.relu(self.conv_ig(x) + self.conv_hg(h_prev))
        output_gate = torch.sigmoid(self.conv_io(x) + self.conv_ho(h_prev))

        c = forget_gate * c_prev + input_gate * candidate
        h = output_gate * F.relu(c)

        return h, c
89
+
90
class ConvLSTM(nn.Module):
    """Applies a single ConvLSTMCell across every frame of a sequence."""

    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super(ConvLSTM, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        # A single cell, shared (unrolled) across all time steps.
        self.conv_lstm_cell = ConvLSTMCell(
            input_channels, hidden_channels, kernel_size, bias
        )

    def forward(self, x):
        """Run the cell over the time axis.

        Args:
            x: Tensor of shape (batch, channels, time, H, W).

        Returns:
            Tensor of shape (batch, hidden_channels, time, H, W) containing
            the hidden state at every time step.
        """
        batch_size, _, sequence_length, height, width = x.size()

        # Zero-initialized hidden and cell states on the input's device.
        state_shape = (batch_size, self.hidden_channels, height, width)
        h = torch.zeros(state_shape, device=x.device)
        c = torch.zeros(state_shape, device=x.device)

        per_step = []
        for t in range(sequence_length):
            h, c = self.conv_lstm_cell(x[:, :, t], (h, c))
            per_step.append(h)

        # Stack gives (T, B, C, H, W); rearrange to (B, C, T, H, W).
        return torch.stack(per_step, dim=0).permute(1, 2, 0, 3, 4).contiguous()
119
+
120
class NextFramePredictionModel(nn.Module):
    """Stack of ConvLSTM stages mapping an RGB sequence to an RGB sequence."""

    def __init__(self):
        super().__init__()
        hidden = 256  # hidden channel width for every intermediate stage

        def stage(in_channels, kernel):
            # ConvLSTM followed by 3D batch-norm over the hidden channels.
            return nn.Sequential(
                ConvLSTM(in_channels, hidden, kernel),
                nn.BatchNorm3d(hidden),
            )

        # Attribute names and Sequential layout are fixed by the checkpoint.
        self.convlstm0 = stage(3, 5)
        self.convlstm1 = stage(hidden, 3)
        self.convlstm2 = stage(hidden, 1)
        self.final = ConvLSTM(hidden, 3, 1)

    def forward(self, x):
        """x: (batch, 3, time, H, W) -> (batch, 3, time, H, W)."""
        for block in (self.convlstm0, self.convlstm1, self.convlstm2):
            x = block(x)
        return self.final(x)
144
+
145
class ModelWrapper(nn.Module):
    """Thin wrapper so that checkpoint keys are prefixed with ``arch.``."""

    def __init__(self):
        super().__init__()
        self.arch = NextFramePredictionModel()

    def forward(self, x):
        # Delegate straight to the wrapped architecture.
        return self.arch(x)
153
+
154
def preprocess_image(image):
    """Convert a NumPy image (as supplied by Gradio) into a model-ready tensor.

    Args:
        image: Input image as a NumPy array of shape (H, W, C).

    Returns:
        torch.Tensor: Tensor of shape (1, 3, 1, 256, 256), i.e.
        (batch, channels, time, height, width).
    """
    pipeline = transforms.Compose([
        transforms.Resize((256, 256)),  # default size as per training
        transforms.ToTensor(),
        # Normalization intentionally disabled to match the trained checkpoint:
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    pil_image = Image.fromarray(image).convert('RGB')
    tensor = pipeline(pil_image)          # (3, 256, 256)
    tensor = tensor.unsqueeze(0)          # (1, 3, 256, 256) — add batch dim
    tensor = tensor.permute(1, 0, 2, 3)   # (3, 1, 256, 256) — layout used in training
    return tensor.unsqueeze(0)            # (1, 3, 1, 256, 256)
175
+
176
+
177
def preprocess_image_no_normalize(image_path: str):
    """Load an image from disk and convert it into a model-ready tensor.

    Args:
        image_path: Path to the input image file.

    Returns:
        torch.Tensor: Tensor of shape (1, 3, 1, 256, 256), i.e.
        (batch, channels, time, height, width). No normalization is applied.
    """
    pipeline = transforms.Compose([
        transforms.Resize((256, 256)),  # default size as per training
        transforms.ToTensor(),
    ])

    pil_image = Image.open(image_path).convert("RGB")
    tensor = pipeline(pil_image)          # (3, 256, 256)
    tensor = tensor.unsqueeze(0)          # (1, 3, 256, 256) — add batch dim
    tensor = tensor.permute(1, 0, 2, 3)   # (3, 1, 256, 256) — layout used in training
    return tensor.unsqueeze(0)            # (1, 3, 1, 256, 256)
197
+
198
+
199
def denormalize_image(output_image: torch.Tensor):
    """Undo ImageNet normalization on a single image tensor.

    Args:
        output_image: Image tensor in either (H, W, 3) or (C, H, W) layout.

    Returns:
        torch.Tensor: Denormalized image in (H, W, C) layout.
    """
    # Accept HWC input by moving channels first.
    if output_image.ndimension() == 3 and output_image.shape[2] == 3:
        output_image = output_image.permute(2, 0, 1)

    # ImageNet per-channel statistics, shaped (3, 1, 1) for broadcasting.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    restored = output_image * std + mean

    # Back to HWC for visualization.
    return restored.permute(1, 2, 0)
222
+
223
+
224
+
225
+
226
+
227
def load_model(model_path: str, device: str):
    """Load the trained next-frame prediction model in evaluation mode.

    Args:
        model_path: Path to the saved state_dict file (e.g. 'model (2).pth').
        device: Device to map the weights onto ('cpu' or 'cuda').

    Returns:
        torch.nn.Module: The wrapped model, ready for inference.
    """
    model = ModelWrapper()
    # weights_only=True restricts unpickling to tensors/containers (safe load).
    state = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.eval()  # inference mode: freeze batch-norm statistics
    return model
245
+
246
+
247
def calculate_time_steps(temperature: float, base_temperature: float = 25, Q10: float = 2):
    """Convert a storage temperature into a number of prediction steps.

    Applies the Q10 rule: the spoilage rate multiplies by ``Q10`` for every
    10 degrees above ``base_temperature``.

    Args:
        temperature: The current temperature.
        base_temperature: Temperature the model is calibrated for (default 25).
        Q10: The Q10 coefficient (default 2).

    Returns:
        int: Number of prediction steps, never less than 1.
    """
    base_rate = 1  # one step per day at base_temperature (25 °C)
    adjusted_rate = base_rate * Q10 ** ((temperature - base_temperature) / 10)
    steps = round(adjusted_rate)
    # Always perform at least one prediction step.
    return steps if steps > 1 else 1
262
+
263
+
264
def predict_next_frame(image, model: torch.nn.Module, num_steps: int = 1):
    """Autoregressively predict future frame(s) from a single input image.

    Each predicted sequence is fed back in as the next input, so the result
    after ``num_steps`` iterations represents ``num_steps`` units of
    simulated time.

    Args:
        image: Input image as a NumPy array (as supplied by Gradio).
        model: Loaded PyTorch next-frame prediction model.
        num_steps: Number of autoregressive prediction steps (must be >= 1).

    Returns:
        torch.Tensor: Predicted frame of shape (H, W, 3) on the CPU.

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    if num_steps < 1:
        # The original fell through to an UnboundLocalError here; fail clearly.
        raise ValueError("num_steps must be at least 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Preprocess the input image to (1, 3, 1, 256, 256).
    input_tensor = preprocess_image(image).to(device)

    # Iteratively predict, feeding each output back in as the next input.
    # no_grad is hoisted around the whole loop (was re-entered per step).
    with torch.no_grad():
        output_tensor = input_tensor
        for _ in range(num_steps):
            output_tensor = model(input_tensor)
            input_tensor = output_tensor

    # (B, C, T, H, W) -> (B, T, H, W, C); take first batch item, first frame.
    output_frame = output_tensor.permute(0, 2, 3, 4, 1)[0][0].detach().cpu()

    return output_frame  # denormalize_image(output_frame) disabled, as before
294
+
295
+
296
def load_and_predict(image, temperature: float = 25, model_path: str = 'model (2).pth'):
    """Load the model and predict spoilage progression for one image.

    The number of autoregressive steps scales with temperature via the Q10
    rule, so warmer temperatures simulate further into the spoilage process.

    Args:
        image: Input image as a NumPy array.
        temperature: The current storage temperature (default 25).
        model_path: Path to the saved model weights.

    Returns:
        torch.Tensor: Predicted frame of shape (H, W, 3).
    """
    # Determine the device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model.
    model = load_model(model_path, device)

    # Warmer storage -> more simulated time steps.
    # (Removed leftover debug print of num_steps.)
    num_steps = calculate_time_steps(temperature)

    # Predict the next frame(s).
    return predict_next_frame(image, model, num_steps=num_steps)
319
+
320
+
321
+
322
def predict(image, temperature: float = 25):
    """Gradio entry point: run the prediction and return a displayable image.

    The model returns a float tensor of shape (H, W, 3); clamp it to [0, 1]
    and convert to NumPy so Gradio can render it.
    """
    frame = load_and_predict(image, temperature)
    return frame.clamp(0, 1).numpy()


# Gradio interface.
# Fixes: `fn=predict` previously referenced an undefined name (NameError at
# import time); the output was declared "text" although the pipeline produces
# an image; the temperature parameter was never exposed to the user.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="numpy"),
        gr.Number(value=25, label="Temperature (°C)"),
    ],
    outputs=gr.Image(label="Predicted frame"),
    title="Banana Predictor",
    description="Upload image.",
)

if __name__ == "__main__":
    interface.launch()
model (2).pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3d85e1e5bf3b1f05645c96233dea360c07880de846cb580ed5874121a66f7aa
3
+ size 47567427
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ Pillow