HimankJ commited on
Commit
8db8e2d
·
verified ·
1 Parent(s): 83ccbef

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. utils/model_loader.py +28 -1
utils/model_loader.py CHANGED
@@ -5,6 +5,7 @@ import os
5
  import json
6
  from torchvision import transforms
7
  import timm
 
8
 
9
  class ModelLoader:
10
  def __init__(self, bucket_name: str, model_name: str = "resnet18", num_classes: int = 13):
@@ -16,6 +17,9 @@ class ModelLoader:
16
  # Create directories if they don't exist
17
  os.makedirs('model', exist_ok=True)
18
 
 
 
 
19
  # Download and load model
20
  self.download_latest_model()
21
  self.model = self.load_model()
@@ -24,10 +28,33 @@ class ModelLoader:
24
  self.labels = self.get_labels()
25
  self.facts = self.get_facts()
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def download_latest_model(self):
28
  """Download the latest model from S3"""
29
- s3 = boto3.client('s3')
30
  try:
 
 
 
 
 
 
 
 
 
 
31
  s3.download_file(
32
  self.bucket_name,
33
  'latest/model.pt',
 
5
  import json
6
  from torchvision import transforms
7
  import timm
8
+ from pathlib import Path
9
 
10
  class ModelLoader:
11
  def __init__(self, bucket_name: str, model_name: str = "resnet18", num_classes: int = 13):
 
17
  # Create directories if they don't exist
18
  os.makedirs('model', exist_ok=True)
19
 
20
+ # Set AWS credentials path
21
+ self.aws_credentials_path = self.get_aws_credentials_path()
22
+
23
  # Download and load model
24
  self.download_latest_model()
25
  self.model = self.load_model()
 
28
  self.labels = self.get_labels()
29
  self.facts = self.get_facts()
30
 
31
+ def get_aws_credentials_path(self):
32
+ """Get the path to AWS credentials file"""
33
+ # Check current directory first
34
+ local_aws_dir = Path('.aws')
35
+ if local_aws_dir.exists():
36
+ return local_aws_dir
37
+
38
+ # Check home directory next
39
+ home_aws_dir = Path.home() / '.aws'
40
+ if home_aws_dir.exists():
41
+ return home_aws_dir
42
+
43
+ raise FileNotFoundError("AWS credentials directory not found")
44
+
45
  def download_latest_model(self):
46
  """Download the latest model from S3"""
 
47
  try:
48
+ # Create boto3 session with specific credentials file
49
+ session = boto3.Session(
50
+ profile_name='default',
51
+ aws_shared_credentials_file=str(self.aws_credentials_path / 'credentials')
52
+ )
53
+
54
+ # Create S3 client using the session
55
+ s3 = session.client('s3')
56
+
57
+ # Download the model
58
  s3.download_file(
59
  self.bucket_name,
60
  'latest/model.pt',