dixisouls committed on
Commit
ce43c34
·
1 Parent(s): 91a5e40
Files changed (4) hide show
  1. Dockerfile +10 -4
  2. app.py +54 -16
  3. app/download_resnet.py +42 -0
  4. app/image_captioning_service.py +70 -10
Dockerfile CHANGED
@@ -22,11 +22,14 @@ RUN mkdir -p app/models && chmod 777 app/models
22
  COPY app ./app
23
  COPY app.py .
24
 
25
- # Create NLTK data directory with proper permissions
26
- RUN mkdir -p /usr/local/share/nltk_data && chmod 777 /usr/local/share/nltk_data
 
 
27
 
28
- # Set NLTK_DATA environment variable
29
- ENV NLTK_DATA=/usr/local/share/nltk_data
 
30
 
31
  # Download NLTK data with explicit directory
32
  RUN python -c "import nltk; nltk.download('punkt', download_dir='/usr/local/share/nltk_data')"
@@ -34,6 +37,9 @@ RUN python -c "import nltk; nltk.download('punkt', download_dir='/usr/local/shar
34
  # Download model files during build
35
  RUN python -m app.download_model
36
 
 
 
 
37
  # Expose port
38
  EXPOSE 7860
39
 
 
22
  COPY app ./app
23
  COPY app.py .
24
 
25
+ # Create cache directories with proper permissions
26
+ RUN mkdir -p /.cache && chmod 777 /.cache
27
+ RUN mkdir -p /root/.cache/torch && chmod -R 777 /root/.cache
28
+ RUN mkdir -p /home/.cache/torch && chmod -R 777 /home/.cache
29
 
30
+ # Set PyTorch cache environment variable
31
+ ENV TORCH_HOME=/home/.cache/torch
32
+ ENV TRANSFORMERS_CACHE=/home/.cache/transformers
33
 
34
  # Download NLTK data with explicit directory
35
  RUN python -c "import nltk; nltk.download('punkt', download_dir='/usr/local/share/nltk_data')"
 
37
  # Download model files during build
38
  RUN python -m app.download_model
39
 
40
+ # Download ResNet50 model to avoid permission issues at runtime
41
+ RUN python -m app.download_resnet
42
+
43
  # Expose port
44
  EXPOSE 7860
45
 
app.py CHANGED
@@ -1,19 +1,3 @@
1
- """
2
- Main application entry point for Image Captioning API
3
- """
4
- import os
5
- import sys
6
- import logging
7
- import nltk
8
-
9
- # Configure logging
10
- logging.basicConfig(
11
- level=logging.INFO,
12
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
13
- )
14
- logger = logging.getLogger(__name__)
15
-
16
- # Setup NLTK data path
17
  def setup_nltk():
18
  """Set up NLTK data directory and ensure punkt tokenizer is available"""
19
  logger.info("Setting up NLTK...")
@@ -51,6 +35,57 @@ def setup_nltk():
51
  # If we get here, we couldn't download punkt anywhere
52
  logger.error("Could not download NLTK punkt tokenizer to any location")
53
  logger.error("The application may not function correctly")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  # Check if model files exist and download if needed
56
  def ensure_models_exist():
@@ -65,6 +100,9 @@ def ensure_models_exist():
65
  logger.info("Model files found.")
66
 
67
  if __name__ == "__main__":
 
 
 
68
  # Setup NLTK
69
  setup_nltk()
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def setup_nltk():
2
  """Set up NLTK data directory and ensure punkt tokenizer is available"""
3
  logger.info("Setting up NLTK...")
 
35
  # If we get here, we couldn't download punkt anywhere
36
  logger.error("Could not download NLTK punkt tokenizer to any location")
37
  logger.error("The application may not function correctly")
38
+ """
39
+ Main application entry point for Image Captioning API
40
+ """
41
+ import os
42
+ import sys
43
+ import logging
44
+ import nltk
45
+
46
+ # Configure logging
47
+ logging.basicConfig(
48
+ level=logging.INFO,
49
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
50
+ )
51
+ logger = logging.getLogger(__name__)
52
+
53
+ # Set up cache directories for PyTorch and other libraries
54
def setup_cache_directories():
    """Create cache directories and point TORCH_HOME at a writable one.

    The function works in two best-effort passes and never raises for
    filesystem errors; every failure is logged as a warning instead.

    1. Each well-known cache location is created (if missing) and its mode
       is set to ``0o777``.
    2. ``TORCH_HOME`` is set to the first candidate directory that exists
       (or can be created) *and* is writable by the current process. If no
       candidate qualifies, ``TORCH_HOME`` is left as it was.
    """
    cache_dirs = (
        '/.cache',
        '/root/.cache',
        '/root/.cache/torch',
        '/home/.cache',
        '/home/.cache/torch',
        '/tmp/.cache',
        '/tmp/.cache/torch',
    )

    for directory in cache_dirs:
        try:
            os.makedirs(directory, exist_ok=True)
        except OSError as e:
            logger.warning(f"Could not create cache directory {directory}: {e}")
            continue

        # NOTE(review): 0o777 makes the directory world-writable; presumably
        # needed because the runtime user differs from the build user - confirm.
        try:
            os.chmod(directory, 0o777)
        except OSError as e:
            logger.warning(f"Could not set permissions for {directory}: {e}")
        else:
            logger.info(f"Created cache directory with permissions: {directory}")

    # Pick the first candidate that can actually be written to.
    for cache_dir in ('/home/.cache/torch', '/tmp/.cache/torch', './torch_cache'):
        try:
            os.makedirs(cache_dir, exist_ok=True)
        except OSError as e:
            logger.warning(f"Could not use {cache_dir} as TORCH_HOME: {e}")
            continue

        # makedirs(exist_ok=True) also succeeds for a pre-existing directory
        # this process cannot write to, so writability is checked explicitly;
        # otherwise the fallback candidates would never be reached.
        if not os.access(cache_dir, os.W_OK | os.X_OK):
            logger.warning(
                f"Could not use {cache_dir} as TORCH_HOME: directory is not writable"
            )
            continue

        os.environ['TORCH_HOME'] = cache_dir
        logger.info(f"Set TORCH_HOME to {cache_dir}")
        break

    logger.info(f"TORCH_HOME is set to: {os.environ.get('TORCH_HOME', 'Not set')}")
89
 
90
  # Check if model files exist and download if needed
91
  def ensure_models_exist():
 
100
  logger.info("Model files found.")
101
 
102
  if __name__ == "__main__":
103
+ # Setup cache directories
104
+ setup_cache_directories()
105
+
106
  # Setup NLTK
107
  setup_nltk()
108
 
app/download_resnet.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to download ResNet50 model and save it locally to avoid
4
+ permission issues when downloading at runtime.
5
+ """
6
+
7
+ import os
8
+ import torch
9
+ import torchvision.models as models
10
+ import logging
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
def download_resnet():
    """Download ResNet50 and save its state dict to app/models/resnet50.pth.

    The torch cache is redirected to ``/tmp/torch_cache`` before the
    pretrained model is requested.

    Returns:
        bool: True when the state dict was written, False when loading or
        saving the model failed (the error is logged with its traceback).
    """
    logger.info("Downloading ResNet50 model...")

    # Create models directory if it doesn't exist
    os.makedirs("app/models", exist_ok=True)

    # Point the torch cache at a directory this script creates itself
    os.makedirs("/tmp/torch_cache", exist_ok=True)
    os.environ["TORCH_HOME"] = "/tmp/torch_cache"

    output_path = "app/models/resnet50.pth"
    try:
        # Load the model with pretrained weights, then persist only the
        # state dict to the local models directory.
        model = models.resnet50(pretrained=True)
        torch.save(model.state_dict(), output_path)
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error log drops
        logger.exception(f"Error downloading ResNet50 model: {e}")
        return False

    logger.info(f"ResNet50 model saved to {output_path}")
    return True


if __name__ == "__main__":
    # Propagate failure as a non-zero exit status so that a caller running
    # this as a script (e.g. a build step) does not silently continue
    # without the saved weights.
    raise SystemExit(0 if download_resnet() else 1)
app/image_captioning_service.py CHANGED
@@ -123,7 +123,46 @@ class EncoderCNN(torch.nn.Module):
123
  super(EncoderCNN, self).__init__()
124
  # Load pretrained ResNet
125
  import torchvision.models as models
126
- resnet = models.resnet50(pretrained=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  # Remove the final FC layer
128
  modules = list(resnet.children())[:-1]
129
  self.resnet = torch.nn.Sequential(*modules)
@@ -324,6 +363,14 @@ def generate_caption(
324
  if not os.path.exists(vocab_path):
325
  raise FileNotFoundError(f"Vocabulary not found at {vocab_path}")
326
 
 
 
 
 
 
 
 
 
327
  # Load vocabulary
328
  logger.info(f"Loading vocabulary from {vocab_path}")
329
  vocab = Vocabulary.load(vocab_path)
@@ -348,18 +395,31 @@ def generate_caption(
348
 
349
  # Load model weights
350
  logger.info(f"Loading model weights from {model_path}")
351
- checkpoint = torch.load(model_path, map_location=device)
352
- model.load_state_dict(checkpoint['model_state_dict'])
353
- model.eval()
 
 
 
 
 
354
 
355
  # Load and process image
356
  logger.info(f"Loading and processing image from {image_path}")
357
- image = load_image(image_path)
358
- image = image.to(device)
 
 
 
 
 
359
 
360
  # Generate caption
361
  logger.info("Generating caption")
362
- caption = model.generate_caption(image, vocab, max_length=max_length)
363
- logger.info(f"Generated caption: {caption}")
364
-
365
- return caption
 
 
 
 
123
  super(EncoderCNN, self).__init__()
124
  # Load pretrained ResNet
125
  import torchvision.models as models
126
+
127
+ # Try different approaches to load ResNet50
128
+ resnet = None
129
+
130
+ # Option 1: Try to load the locally saved model
131
+ try:
132
+ logger.info("Trying to load locally saved ResNet50 model...")
133
+ resnet = models.resnet50(pretrained=False)
134
+ local_model_path = "app/models/resnet50.pth"
135
+ if os.path.exists(local_model_path):
136
+ resnet.load_state_dict(torch.load(local_model_path))
137
+ logger.info("Successfully loaded ResNet50 from local file")
138
+ else:
139
+ logger.warning(f"Local ResNet50 model not found at {local_model_path}")
140
+ # Fall back to pretrained model
141
+ resnet = None
142
+ except Exception as e:
143
+ logger.warning(f"Error loading local ResNet50 model: {str(e)}")
144
+ resnet = None
145
+
146
+ # Option 2: Try loading with pretrained weights
147
+ if resnet is None:
148
+ try:
149
+ logger.info("Trying to load ResNet50 with pretrained weights...")
150
+ # Set cache directory
151
+ os.makedirs('/tmp/torch_cache', exist_ok=True)
152
+ os.environ['TORCH_HOME'] = '/tmp/torch_cache'
153
+
154
+ resnet = models.resnet50(pretrained=True)
155
+ logger.info("Successfully loaded pretrained ResNet50 model")
156
+ except Exception as e:
157
+ logger.warning(f"Error loading pretrained ResNet50: {str(e)}")
158
+ resnet = None
159
+
160
+ # Option 3: Fall back to model without pretrained weights
161
+ if resnet is None:
162
+ logger.info("Falling back to ResNet50 without pretrained weights...")
163
+ resnet = models.resnet50(pretrained=False)
164
+ logger.warning("Using ResNet50 WITHOUT pretrained weights - captions may be less accurate")
165
+
166
  # Remove the final FC layer
167
  modules = list(resnet.children())[:-1]
168
  self.resnet = torch.nn.Sequential(*modules)
 
363
  if not os.path.exists(vocab_path):
364
  raise FileNotFoundError(f"Vocabulary not found at {vocab_path}")
365
 
366
+ # Setup temporary cache directory for torch if needed
367
+ try:
368
+ os.makedirs('/tmp/torch_cache', exist_ok=True)
369
+ os.environ['TORCH_HOME'] = '/tmp/torch_cache'
370
+ logger.info(f"Set TORCH_HOME to /tmp/torch_cache")
371
+ except Exception as e:
372
+ logger.warning(f"Could not set up temporary torch cache: {e}")
373
+
374
  # Load vocabulary
375
  logger.info(f"Loading vocabulary from {vocab_path}")
376
  vocab = Vocabulary.load(vocab_path)
 
395
 
396
  # Load model weights
397
  logger.info(f"Loading model weights from {model_path}")
398
+ try:
399
+ checkpoint = torch.load(model_path, map_location=device)
400
+ model.load_state_dict(checkpoint['model_state_dict'])
401
+ model.eval()
402
+ logger.info("Model loaded successfully")
403
+ except Exception as e:
404
+ logger.error(f"Error loading model: {str(e)}")
405
+ raise
406
 
407
  # Load and process image
408
  logger.info(f"Loading and processing image from {image_path}")
409
+ try:
410
+ image = load_image(image_path)
411
+ image = image.to(device)
412
+ logger.info("Image processed successfully")
413
+ except Exception as e:
414
+ logger.error(f"Error processing image: {str(e)}")
415
+ raise
416
 
417
  # Generate caption
418
  logger.info("Generating caption")
419
+ try:
420
+ caption = model.generate_caption(image, vocab, max_length=max_length)
421
+ logger.info(f"Generated caption: {caption}")
422
+ return caption
423
+ except Exception as e:
424
+ logger.error(f"Error generating caption: {str(e)}")
425
+ raise