dixisouls committed on
Commit
9f11f00
·
1 Parent(s): ce43c34

Vocabulary class error

Browse files
Files changed (3) hide show
  1. app.py +54 -33
  2. app/api.py +9 -1
  3. app/image_captioning_service.py +36 -1
app.py CHANGED
@@ -1,41 +1,59 @@
1
- def setup_nltk():
2
- """Set up NLTK data directory and ensure punkt tokenizer is available"""
3
- logger.info("Setting up NLTK...")
4
-
5
- # Create potential NLTK data directories with proper permissions
6
- nltk_dirs = [
7
- os.path.expanduser('~/.nltk_data'),
8
- './nltk_data',
9
- '/usr/local/share/nltk_data'
10
- ]
11
-
12
- for directory in nltk_dirs:
13
- try:
14
- os.makedirs(directory, exist_ok=True)
15
- logger.info(f"Created NLTK data directory: {directory}")
16
- except Exception as e:
17
- logger.warning(f"Could not create NLTK directory {directory}: {e}")
18
-
19
- # Try to find punkt tokenizer
20
  try:
21
- nltk.data.find('tokenizers/punkt')
22
- logger.info("NLTK punkt tokenizer found!")
23
- return
24
- except LookupError:
25
- # Not found, try to download to different locations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  for directory in nltk_dirs:
27
  try:
28
- logger.info(f"Attempting to download punkt tokenizer to {directory}")
29
- nltk.download('punkt', download_dir=directory)
30
- logger.info(f"Successfully downloaded punkt tokenizer to {directory}")
31
- return
32
  except Exception as e:
33
- logger.warning(f"Failed to download punkt to {directory}: {e}")
34
 
35
- # If we get here, we couldn't download punkt anywhere
36
- logger.error("Could not download NLTK punkt tokenizer to any location")
37
- logger.error("The application may not function correctly")
38
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  Main application entry point for Image Captioning API
40
  """
41
  import os
@@ -106,6 +124,9 @@ if __name__ == "__main__":
106
  # Setup NLTK
107
  setup_nltk()
108
 
 
 
 
109
  # Ensure model files exist
110
  ensure_models_exist()
111
 
 
1
def register_vocabulary_in_main():
    """Expose the captioning classes as attributes of the ``__main__`` module.

    pickle resolves a class by the module path and name stored in the
    pickled data, so objects recorded as ``__main__.<ClassName>`` can only be
    loaded if ``__main__`` has an attribute of that name. This function
    copies ``Vocabulary``, ``ImageCaptioningModel``, ``EncoderCNN``,
    ``TransformerDecoder`` and ``PositionalEncoding`` from
    ``app.image_captioning_service`` onto ``__main__``.

    Registration is best effort: any failure (including a failed import) is
    logged as a warning and otherwise ignored.
    """
    try:
        logger.info("Registering Vocabulary class in __main__ module")
        import __main__
        from app.image_captioning_service import (
            Vocabulary,
            ImageCaptioningModel,
            EncoderCNN,
            TransformerDecoder,
            PositionalEncoding,
        )

        # Explicit name -> class table: the attribute names must match what
        # the pickled data refers to, so they are spelled out rather than
        # derived from ``cls.__name__``.
        classes_to_register = {
            'Vocabulary': Vocabulary,
            'ImageCaptioningModel': ImageCaptioningModel,
            'EncoderCNN': EncoderCNN,
            'TransformerDecoder': TransformerDecoder,
            'PositionalEncoding': PositionalEncoding,
        }
        for attr_name, cls in classes_to_register.items():
            setattr(__main__, attr_name, cls)

        logger.info("Successfully registered classes in __main__")
    except Exception as e:
        # Deliberately non-fatal: the caller continues with startup.
        logger.warning(f"Could not register classes in __main__: {e}")
19
def setup_nltk():
    """Ensure the NLTK ``punkt`` tokenizer is available, downloading it if needed.

    Steps:
      1. Create each candidate data directory (best effort) and add it to
         ``nltk.data.path`` so data stored there can be found.
      2. If ``tokenizers/punkt`` is already resolvable, return.
      3. Otherwise try to download ``punkt`` into each candidate directory
         in order, returning after the first successful download.

    Nothing is raised to the caller; if every attempt fails, two error
    messages are logged and the function returns.
    """
    logger.info("Setting up NLTK...")

    # Candidate data directories, tried in this order.
    nltk_dirs = [
        os.path.expanduser('~/.nltk_data'),
        './nltk_data',
        '/usr/local/share/nltk_data'
    ]

    for directory in nltk_dirs:
        try:
            os.makedirs(directory, exist_ok=True)
            logger.info(f"Created NLTK data directory: {directory}")
        except Exception as e:
            logger.warning(f"Could not create NLTK directory {directory}: {e}")
            continue
        # Not every candidate is on NLTK's default search path; register it
        # so that data downloaded here is visible to nltk.data.find().
        if directory not in nltk.data.path:
            nltk.data.path.append(directory)

    # Try to find punkt tokenizer
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        pass  # Not found, fall through to the download attempts below.
    else:
        logger.info("NLTK punkt tokenizer found!")
        return

    for directory in nltk_dirs:
        try:
            logger.info(f"Attempting to download punkt tokenizer to {directory}")
            # nltk.download() signals failure by returning False instead of
            # raising, so the return value must be checked.
            downloaded = nltk.download('punkt', download_dir=directory)
        except Exception as e:
            logger.warning(f"Failed to download punkt to {directory}: {e}")
            continue

        if downloaded:
            logger.info(f"Successfully downloaded punkt tokenizer to {directory}")
            return
        logger.warning(f"Failed to download punkt to {directory}: download reported failure")

    # If we get here, we couldn't download punkt anywhere
    logger.error("Could not download NLTK punkt tokenizer to any location")
    logger.error("The application may not function correctly")
56
+ """
57
  Main application entry point for Image Captioning API
58
  """
59
  import os
 
124
  # Setup NLTK
125
  setup_nltk()
126
 
127
+ # Register Vocabulary in main module
128
+ register_vocabulary_in_main()
129
+
130
  # Ensure model files exist
131
  ensure_models_exist()
132
 
app/api.py CHANGED
@@ -9,7 +9,15 @@ from typing import Dict, Any
9
  import torch
10
 
11
  # Import image captioning service
12
- from app.image_captioning_service import generate_caption
 
 
 
 
 
 
 
 
13
 
14
  # Configure logging
15
  logging.basicConfig(level=logging.INFO)
 
9
  import torch
10
 
11
  # Import image captioning service
12
+ from app.image_captioning_service import generate_caption, Vocabulary, ImageCaptioningModel, EncoderCNN, TransformerDecoder, PositionalEncoding
13
+
14
+ # Register these classes in the main module to help with unpickling
15
+ import __main__
16
+ setattr(__main__, 'Vocabulary', Vocabulary)
17
+ setattr(__main__, 'ImageCaptioningModel', ImageCaptioningModel)
18
+ setattr(__main__, 'EncoderCNN', EncoderCNN)
19
+ setattr(__main__, 'TransformerDecoder', TransformerDecoder)
20
+ setattr(__main__, 'PositionalEncoding', PositionalEncoding)
21
 
22
  # Configure logging
23
  logging.basicConfig(level=logging.INFO)
app/image_captioning_service.py CHANGED
@@ -325,6 +325,8 @@ class ImageCaptioningModel(torch.nn.Module):
325
 
326
  return ' '.join(words)
327
 
 
 
328
  def load_image(image_path, transform=None):
329
  """Load and preprocess an image"""
330
  image = Image.open(image_path).convert('RGB')
@@ -396,7 +398,40 @@ def generate_caption(
396
  # Load model weights
397
  logger.info(f"Loading model weights from {model_path}")
398
  try:
399
- checkpoint = torch.load(model_path, map_location=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  model.load_state_dict(checkpoint['model_state_dict'])
401
  model.eval()
402
  logger.info("Model loaded successfully")
 
325
 
326
  return ' '.join(words)
327
 
328
+
329
+
330
  def load_image(image_path, transform=None):
331
  """Load and preprocess an image"""
332
  image = Image.open(image_path).convert('RGB')
 
398
  # Load model weights
399
  logger.info(f"Loading model weights from {model_path}")
400
  try:
401
+ # First try our custom loader
402
+ try:
403
+ logger.info("Trying custom model loader...")
404
+ # Replace this with Python's built-in pickle that we can customize
405
+ # Define a custom unpickler
406
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that redirects selected class lookups to the classes in scope here.

    NOTE(review): unpickling can execute arbitrary code; this should only
    be pointed at checkpoint files from a trusted source.
    """

    def find_class(self, module, name):
        """Return the class to use for the pickled reference ``module.name``.

        ``Vocabulary`` is substituted whatever module the pickle recorded
        it under. The model classes are substituted only when the recorded
        module is ``__main__``. Every other reference is resolved by the
        standard ``pickle.Unpickler.find_class``.
        """
        # Vocabulary is matched by name alone (the module is not checked).
        if name == 'Vocabulary':
            return Vocabulary

        # Anything not recorded under __main__ gets the default lookup.
        if module != '__main__':
            return super().find_class(module, name)

        if name == 'ImageCaptioningModel':
            return ImageCaptioningModel
        elif name == 'EncoderCNN':
            return EncoderCNN
        elif name == 'TransformerDecoder':
            return TransformerDecoder
        elif name == 'PositionalEncoding':
            return PositionalEncoding

        # Unrecognised __main__ names also use the default lookup.
        return super().find_class(module, name)
423
+
424
+ # Use a custom loading approach
425
+ with open(model_path, 'rb') as f:
426
+ checkpoint = CustomUnpickler(f).load()
427
+
428
+ logger.info("Successfully loaded model using custom unpickler")
429
+ except Exception as e:
430
+ logger.warning(f"Custom loader failed: {str(e)}")
431
+ logger.info("Falling back to standard torch.load...")
432
+ # Fall back to standard loader
433
+ checkpoint = torch.load(model_path, map_location=device)
434
+
435
  model.load_state_dict(checkpoint['model_state_dict'])
436
  model.eval()
437
  logger.info("Model loaded successfully")