Fred808 commited on
Commit
eb28e94
·
verified ·
1 Parent(s): 9068030

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -20,6 +20,7 @@ from torchvision import models, transforms
20
  import matplotlib.pyplot as plt
21
  import seaborn as sns
22
  from collections import Counter
 
23
  import pickle
24
 
25
  # Set up logging
@@ -187,7 +188,6 @@ with open(CACHE_FILE, "wb") as f:
187
 
188
  logging.info("Incremental processing complete!")
189
 
190
- # Analyze image content using a pre-trained model
191
  def analyze_image(image_url):
192
  """Analyze image content using a pre-trained model."""
193
  if not image_url or not isinstance(image_url, str) or not image_url.startswith(('http://', 'https://')):
@@ -203,8 +203,13 @@ def analyze_image(image_url):
203
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
204
  ])
205
  image_tensor = preprocess(image).unsqueeze(0)
206
- model = models.resnet50(pretrained=True)
 
 
 
 
207
  model.eval()
 
208
  with torch.no_grad():
209
  output = model(image_tensor)
210
  return output
 
20
  import matplotlib.pyplot as plt
21
  import seaborn as sns
22
  from collections import Counter
23
+ from torchvision.models import ResNet50_Weights
24
  import pickle
25
 
26
  # Set up logging
 
188
 
189
  logging.info("Incremental processing complete!")
190
 
 
191
  def analyze_image(image_url):
192
  """Analyze image content using a pre-trained model."""
193
  if not image_url or not isinstance(image_url, str) or not image_url.startswith(('http://', 'https://')):
 
203
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
204
  ])
205
  image_tensor = preprocess(image).unsqueeze(0)
206
+
207
+ # Load ResNet50 weights from local cache
208
+ weights_path = "/app/models/resnet50-0676ba61.pth"
209
+ model = models.resnet50()
210
+ model.load_state_dict(torch.load(weights_path))
211
  model.eval()
212
+
213
  with torch.no_grad():
214
  output = model(image_tensor)
215
  return output