dixisouls committed
Commit a0c5c81 · 1 Parent(s): f2ddb15

Initial Commit
Dockerfile ADDED
@@ -0,0 +1,35 @@
+ FROM python:3.9-slim
+
+ WORKDIR /code
+
+ # Install system dependencies
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements first to leverage the Docker cache
+ COPY requirements.txt .
+
+ # Install dependencies
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ # Create necessary directories with write permissions
+ RUN mkdir -p /tmp/uploads && chmod 777 /tmp/uploads
+ RUN mkdir -p app/models && chmod 777 app/models
+
+ # Copy application code
+ COPY app ./app
+ COPY app.py .
+
+ # Download NLTK data
+ RUN python -c "import nltk; nltk.download('punkt')"
+
+ # Download model files during build
+ RUN python -m app.download_model
+
+ # Expose port
+ EXPOSE 7860
+
+ # Run the application
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,12 +1,80 @@
- ---
- title: Image Captioning Api
- emoji: 😻
- colorFrom: indigo
- colorTo: blue
- sdk: docker
- pinned: false
- license: mit
- short_description: API Endpoint for Image Captioning
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Image Captioning API
+
+ A RESTful API for generating captions from images using a Transformer-based
+ model. This service is designed to be deployed on Hugging Face Spaces.
+
+ ## Features
+
+ - Upload any image file (jpg, png, etc.)
+ - Get AI-generated captions based on image content
+ - FastAPI-based REST API with interactive documentation
+
+ ## API Endpoints
+
+ - `GET /` - API information and usage
+ - `POST /generate` - Upload an image and get a caption
+ - `GET /health` - Health check endpoint
+ - `GET /docs` - Swagger UI documentation
+
+ ## How to Use
+
+ ### API Request Example
+
+ ```bash
+ curl -X POST "https://your-space-name.hf.space/generate" \
+   -H "accept: application/json" \
+   -H "Content-Type: multipart/form-data" \
+   -F "image=@your_image.jpg" \
+   -F "max_length=20"
+ ```
+
+ ### API Response Example
+
+ ```json
+ {
+   "caption": "a person riding a snowboard down a snow covered slope",
+   "image": "base64_encoded_image_data..."
+ }
+ ```
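+
+ ### Python Client Example
+
+ For programmatic access, a minimal client could look like the sketch below. It assumes the `requests` package is installed (it is not part of `requirements.txt`) and that the placeholder URL is replaced with your deployed Space URL:
+
+ ```python
+ import requests
+
+ # Placeholder; replace with your deployed Space URL
+ API_URL = "https://your-space-name.hf.space"
+
+ # Send the image as multipart/form-data, mirroring the curl example above
+ with open("your_image.jpg", "rb") as f:
+     response = requests.post(
+         f"{API_URL}/generate",
+         files={"image": ("your_image.jpg", f, "image/jpeg")},
+         data={"max_length": 20},
+     )
+
+ response.raise_for_status()
+ print(response.json()["caption"])
+ ```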
+
+ ## Local Development
+
+ ### Prerequisites
+
+ - Python 3.9+
+ - pip
+
+ ### Setup
+
+ 1. Clone the repository
+ 2. Install dependencies:
+    ```
+    pip install -r requirements.txt
+    ```
+ 3. Run the application:
+    ```
+    python app.py
+    ```
+ 4. Visit http://localhost:7860/docs to access the API documentation
+
+ ## Deployment on Hugging Face Spaces
+
+ This application is designed to be deployed on
+ [Hugging Face Spaces](https://huggingface.co/spaces) using Docker.
+
+ 1. Create a new Space on Hugging Face
+ 2. Select Docker as the SDK
+ 3. Upload all files to the repository
+ 4. Hugging Face will automatically build and deploy the application
+
+ ## Technical Details
+
+ - **Model**: ResNet50 encoder with Transformer decoder
+ - **Framework**: PyTorch
+ - **API**: FastAPI
+ - **Image Processing**: torchvision and PIL
+ - **Model Hosting**: Hugging Face Hub
+
+ ## License
+
+ MIT
app.py ADDED
@@ -0,0 +1,35 @@
+ """
+ Main application entry point for the Image Captioning API
+ """
+ import os
+ import logging
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+
+ # Check if model files exist and download them if needed
+ def ensure_models_exist():
+     model_path = "app/models/image_captioning_model.pth"
+     vocab_path = "app/models/vocab.pkl"
+
+     if not os.path.exists(model_path) or not os.path.exists(vocab_path):
+         logger.info("Model files not found. Downloading...")
+         from app.download_model import download_models
+         download_models()
+     else:
+         logger.info("Model files found.")
+
+
+ if __name__ == "__main__":
+     # Ensure model files exist
+     ensure_models_exist()
+
+     # Run the FastAPI application
+     import uvicorn
+     from app.api import app
+
+     logger.info("Starting Image Captioning API server...")
+     uvicorn.run(app, host="0.0.0.0", port=7860)
app/__init__.py ADDED
File without changes
app/api.py ADDED
@@ -0,0 +1,177 @@
+ import os
+ import base64
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ import shutil
+ import uuid
+ import logging
+ from typing import Dict, Any
+ import torch
+
+ # Import image captioning service
+ from app.image_captioning_service import generate_caption
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Use the /tmp directory, which should be writable
+ UPLOAD_DIR = "/tmp/uploads"
+
+ # Create necessary directories
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
+ os.makedirs("app/models", exist_ok=True)
+
+ # Initialize FastAPI app
+ app = FastAPI(title="Image Captioning API")
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Get device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ logger.info(f"Using device: {device}")
+
+
+ @app.get("/")
+ def read_root():
+     return {
+         "message": "Image Captioning API is running",
+         "usage": "POST /generate with an image file to generate a caption",
+         "docs": "Visit /docs for API documentation"
+     }
+
+
+ @app.post("/generate")
+ async def generate_image_caption(
+     image: UploadFile = File(...),
+     max_length: int = Form(20),
+ ) -> Dict[str, Any]:
+     try:
+         # Debug information
+         logger.info(f"Received file: {image.filename}, content_type: {image.content_type}")
+
+         # Input validation with improved error handling
+         if image is None:
+             raise HTTPException(status_code=400, detail="No image file provided")
+
+         if not image.content_type:
+             # Set a default content type if none was provided
+             logger.warning("No content type provided, assuming image/jpeg")
+             image.content_type = "image/jpeg"
+
+         if not image.content_type.startswith("image/"):
+             raise HTTPException(
+                 status_code=400, detail=f"Uploaded file must be an image, got {image.content_type}"
+             )
+
+         if not (0 < max_length <= 100):
+             raise HTTPException(
+                 status_code=400, detail="Maximum caption length must be between 1 and 100"
+             )
+
+         # Generate a unique ID for this job
+         job_id = str(uuid.uuid4())
+         short_id = job_id.split("-")[0]
+
+         # Create a directory for this job in /tmp, which should be writable
+         upload_job_dir = os.path.join(UPLOAD_DIR, job_id)
+
+         # Create the directory with explicit permission setting
+         os.makedirs(upload_job_dir, exist_ok=True, mode=0o777)
+         logger.info(f"Created upload directory: {upload_job_dir}")
+
+         # Determine the file extension
+         file_ext = os.path.splitext(image.filename)[1] if image.filename else ".jpg"
+         if not file_ext:
+             file_ext = ".jpg"
+
+         # Save the uploaded image to /tmp
+         image_filename = f"{short_id}{file_ext}"
+         image_path = os.path.join(upload_job_dir, image_filename)
+
+         # Save the file with error handling
+         try:
+             with open(image_path, "wb") as buffer:
+                 contents = await image.read()
+                 buffer.write(contents)
+
+             # Check that the file was created and is non-empty
+             if not os.path.exists(image_path):
+                 raise HTTPException(status_code=400, detail=f"Failed to save uploaded file to {image_path}")
+
+             if os.path.getsize(image_path) == 0:
+                 raise HTTPException(status_code=400, detail="Uploaded file is empty")
+
+             logger.info(f"Image saved to {image_path} ({os.path.getsize(image_path)} bytes)")
+         except HTTPException:
+             # Let validation errors keep their status codes
+             raise
+         except Exception as e:
+             logger.error(f"Error saving file: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Error saving uploaded file: {str(e)}")
+
+         # Define model paths
+         model_path = "app/models/image_captioning_model.pth"
+         vocabulary_path = "app/models/vocab.pkl"
+
+         # Check that the model files exist
+         if not os.path.exists(model_path):
+             logger.error(f"Model file not found: {model_path}")
+             raise HTTPException(status_code=500, detail=f"Model file not found: {model_path}")
+
+         if not os.path.exists(vocabulary_path):
+             logger.error(f"Vocabulary file not found: {vocabulary_path}")
+             raise HTTPException(status_code=500, detail=f"Vocabulary file not found: {vocabulary_path}")
+
+         # Generate caption
+         try:
+             caption = generate_caption(
+                 image_path=image_path,
+                 model_path=model_path,
+                 vocab_path=vocabulary_path,
+                 max_length=max_length,
+                 device=device
+             )
+             logger.info(f"Generated caption: {caption}")
+         except Exception as e:
+             logger.error(f"Error generating caption: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Error generating caption: {str(e)}")
+
+         # Read the original image back as base64
+         try:
+             with open(image_path, "rb") as img_file:
+                 image_base64 = base64.b64encode(img_file.read()).decode("utf-8")
+             logger.info("Successfully encoded image as base64")
+         except Exception as e:
+             logger.error(f"Error reading image: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}")
+
+         # Prepare the response with the base64-encoded image
+         response = {
+             "caption": caption,
+             "image": image_base64
+         }
+
+         # Clean up
+         try:
+             shutil.rmtree(upload_job_dir)
+             logger.info("Cleaned up temporary directories")
+         except Exception as e:
+             logger.warning(f"Error cleaning up temporary files: {str(e)}")
+
+         return response
+
+     except HTTPException:
+         # Re-raise HTTPExceptions unchanged so 400-level validation errors
+         # are not rewrapped as 500s by the handler below
+         raise
+     except Exception as e:
+         logger.error(f"Error processing image: {str(e)}", exc_info=True)
+         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+
+
+ @app.get("/health")
+ def health_check():
+     return {"status": "healthy"}
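A quick way to exercise these endpoints locally is a small smoke test. The sketch below assumes the model files are already present under `app/models` and that `httpx` is installed (Starlette's `TestClient` depends on it; neither is in `requirements.txt`); `test.jpg` is a placeholder local image:

```python
# Hypothetical smoke test for app/api.py
from fastapi.testclient import TestClient

from app.api import app

client = TestClient(app)

# The /health endpoint should always respond
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "healthy"}

# /generate expects multipart form data, mirroring the README's curl example
with open("test.jpg", "rb") as f:  # test.jpg is a placeholder image
    response = client.post(
        "/generate",
        files={"image": ("test.jpg", f, "image/jpeg")},
        data={"max_length": 20},
    )
print(response.status_code, response.json().get("caption"))
```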
app/download_model.py ADDED
@@ -0,0 +1,55 @@
+ import os
+ import sys
+ from huggingface_hub import hf_hub_download
+ import shutil
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ def download_models():
+     """Download model files from the Hugging Face Hub"""
+     logger.info("Downloading model files...")
+
+     # Create the target directory if it doesn't exist
+     os.makedirs("app/models", exist_ok=True)
+
+     try:
+         # Download the model and vocabulary from Hugging Face
+         logger.info("Downloading model from dixisouls/image-captioning-model...")
+         model_path = hf_hub_download(
+             repo_id="dixisouls/image-captioning-model",
+             filename="image_captioning_model.pth",
+             repo_type="model"
+         )
+
+         logger.info("Downloading vocabulary from dixisouls/image-captioning-model...")
+         vocab_path = hf_hub_download(
+             repo_id="dixisouls/image-captioning-model",
+             filename="vocab.pkl",
+             repo_type="model"
+         )
+
+         # Copy the downloaded files to the app/models directory
+         shutil.copy(model_path, "app/models/image_captioning_model.pth")
+         shutil.copy(vocab_path, "app/models/vocab.pkl")
+
+         logger.info("Model downloaded successfully to app/models/image_captioning_model.pth")
+         logger.info("Vocabulary downloaded successfully to app/models/vocab.pkl")
+
+         # Create a fixed vocabulary file if needed
+         try:
+             from app.fix_vocab_pickle import fix_vocab_pickle
+             fixed_vocab = fix_vocab_pickle("app/models/vocab.pkl", "app/models/vocab_fixed.pkl")
+             if fixed_vocab:
+                 logger.info("Created fixed vocabulary file at app/models/vocab_fixed.pkl")
+         except Exception as e:
+             logger.warning(f"Could not create fixed vocabulary file: {str(e)}")
+
+     except Exception as e:
+         logger.error(f"Error downloading model files: {e}")
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     download_models()
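After the build-time download (`python -m app.download_model`), a quick sanity check can confirm that both files landed where the service expects them. A minimal sketch, using the same paths as `app/api.py`:

```python
# Verify the files the service expects are present and non-trivial in size
import os

for path in ("app/models/image_captioning_model.pth", "app/models/vocab.pkl"):
    if os.path.exists(path):
        print(f"{path}: {os.path.getsize(path) / 1e6:.1f} MB")
    else:
        print(f"{path}: MISSING")
```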
app/fix_vocab_pickle.py ADDED
@@ -0,0 +1,127 @@
+ """
+ Script to fix the vocabulary pickle file by recreating it with correct module information.
+ Run this script if you are still experiencing Vocabulary loading issues.
+ """
+
+ import pickle
+ import os
+ import sys
+ import nltk
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Make sure the NLTK tokenizer is available
+ try:
+     nltk.data.find('tokenizers/punkt')
+ except LookupError:
+     nltk.download('punkt')
+
+
+ # Vocabulary class for loading the vocabulary
+ class Vocabulary:
+     def __init__(self):
+         self.word2idx = {}
+         self.idx2word = {}
+         self.idx = 0
+
+     def add_word(self, word):
+         if word not in self.word2idx:
+             self.word2idx[word] = self.idx
+             self.idx2word[self.idx] = word
+             self.idx += 1
+
+     def __len__(self):
+         return len(self.word2idx)
+
+     def tokenize(self, text):
+         """Tokenize text into a list of tokens"""
+         return nltk.tokenize.word_tokenize(str(text).lower())
+
+
+ def fix_vocab_pickle(input_path, output_path):
+     """
+     Load the vocabulary pickle file and create a new one with updated module information.
+     """
+     try:
+         logger.info(f"Attempting to load vocabulary from {input_path}...")
+
+         # Try first with a very permissive custom unpickler
+         class FixerUnpickler(pickle.Unpickler):
+             def find_class(self, module, name):
+                 # For any class named Vocabulary, regardless of its original
+                 # module, substitute our local Vocabulary class
+                 if name == 'Vocabulary':
+                     return Vocabulary
+                 # Fall back to the default behavior for everything else
+                 return super().find_class(module, name)
+
+         # Try to load with our custom unpickler
+         with open(input_path, 'rb') as f:
+             try:
+                 vocab = FixerUnpickler(f).load()
+                 logger.info("Successfully loaded vocabulary!")
+             except Exception as e:
+                 logger.warning(f"Custom unpickler failed: {str(e)}")
+
+                 # If that fails, try a raw load and extract the data
+                 f.seek(0)  # Reset file pointer
+                 try:
+                     raw_data = pickle.load(f)
+                     logger.info("Loaded raw data, attempting to extract vocabulary...")
+
+                     # Create a new vocabulary
+                     vocab = Vocabulary()
+
+                     # Try to extract the necessary data
+                     if hasattr(raw_data, 'word2idx') and hasattr(raw_data, 'idx2word'):
+                         vocab.word2idx = raw_data.word2idx
+                         vocab.idx2word = raw_data.idx2word
+                         vocab.idx = raw_data.idx if hasattr(raw_data, 'idx') else len(vocab.word2idx)
+                     elif isinstance(raw_data, dict) and 'word2idx' in raw_data and 'idx2word' in raw_data:
+                         vocab.word2idx = raw_data['word2idx']
+                         vocab.idx2word = raw_data['idx2word']
+                         vocab.idx = raw_data.get('idx', len(vocab.word2idx))
+                     else:
+                         logger.error("Could not extract vocabulary data from the pickle file.")
+                         logger.error(f"Raw data type: {type(raw_data)}")
+                         return None
+                 except Exception as e:
+                     logger.error(f"Raw data extraction failed: {str(e)}")
+                     return None
+
+         # Save the vocabulary with the correct module information
+         logger.info(f"Saving fixed vocabulary to {output_path}...")
+         with open(output_path, 'wb') as f:
+             pickle.dump(vocab, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+         logger.info(f"Vocabulary successfully fixed and saved to {output_path}")
+         logger.info(f"Vocabulary size: {len(vocab)} words")
+         logger.info(f"Sample words: {list(vocab.word2idx.keys())[:5]}")
+
+         return vocab
+
+     except Exception as e:
+         logger.error(f"An error occurred: {str(e)}")
+         return None
+
+
+ if __name__ == "__main__":
+     # Parse command line arguments
+     import argparse
+     parser = argparse.ArgumentParser(description='Fix vocabulary pickle file')
+     parser.add_argument('--input', type=str, default='app/models/vocab.pkl',
+                         help='Path to the input vocabulary pickle file')
+     parser.add_argument('--output', type=str, default='app/models/vocab_fixed.pkl',
+                         help='Path to save the fixed vocabulary pickle file')
+     args = parser.parse_args()
+
+     # Run the fix function
+     vocab = fix_vocab_pickle(args.input, args.output)
+
+     if vocab is not None:
+         logger.info("To use the fixed vocabulary, update your paths to use the new file.")
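The same repair can be run programmatically rather than via the CLI, as `app/download_model.py` already does after each download. A short sketch using the script's default paths:

```python
from app.fix_vocab_pickle import fix_vocab_pickle

# Same default paths as the CLI entry point
vocab = fix_vocab_pickle("app/models/vocab.pkl", "app/models/vocab_fixed.pkl")
if vocab is not None:
    print(f"Fixed vocabulary with {len(vocab)} words")
```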
app/image_captioning_service.py ADDED
@@ -0,0 +1,352 @@
+ import os
+ import math
+ import torch
+ from PIL import Image
+ import torchvision.transforms as transforms
+ import nltk
+ import pickle
+ import warnings
+ import logging
+ warnings.filterwarnings("ignore")
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Make sure the NLTK tokenizer is available
+ try:
+     nltk.data.find('tokenizers/punkt')
+ except LookupError:
+     nltk.download('punkt')
+
+
+ # Vocabulary class for loading the vocabulary
+ class Vocabulary:
+     def __init__(self):
+         self.word2idx = {}
+         self.idx2word = {}
+         self.idx = 0
+
+     def add_word(self, word):
+         if word not in self.word2idx:
+             self.word2idx[word] = self.idx
+             self.idx2word[self.idx] = word
+             self.idx += 1
+
+     def __len__(self):
+         return len(self.word2idx)
+
+     def tokenize(self, text):
+         """Tokenize text into a list of tokens"""
+         return nltk.tokenize.word_tokenize(str(text).lower())
+
+     @classmethod
+     def load(cls, path):
+         """Load vocabulary from a pickle file, trying multiple strategies"""
+         try:
+             # Strategy 1: a custom unpickler that maps any class named
+             # Vocabulary, regardless of its original module, to this class
+             class CustomUnpickler(pickle.Unpickler):
+                 def find_class(self, module, name):
+                     if name == 'Vocabulary':
+                         return Vocabulary
+                     # Default behavior for everything else
+                     return super().find_class(module, name)
+
+             with open(path, 'rb') as f:
+                 return CustomUnpickler(f).load()
+         except Exception as first_error:
+             logger.error(f"First loading method failed: {str(first_error)}")
+             try:
+                 # Strategy 2: raw load, then manually rebuild the vocabulary
+                 with open(path, 'rb') as f:
+                     raw_data = pickle.load(f)
+                     # Object exposing the expected attributes
+                     if hasattr(raw_data, 'word2idx') and hasattr(raw_data, 'idx2word'):
+                         vocab = Vocabulary()
+                         vocab.word2idx = raw_data.word2idx
+                         vocab.idx2word = raw_data.idx2word
+                         vocab.idx = raw_data.idx
+                         return vocab
+                     # Plain dict carrying the word mappings
+                     if isinstance(raw_data, dict):
+                         if 'word2idx' in raw_data and 'idx2word' in raw_data:
+                             vocab = Vocabulary()
+                             vocab.word2idx = raw_data['word2idx']
+                             vocab.idx2word = raw_data['idx2word']
+                             vocab.idx = len(vocab.word2idx)
+                             return vocab
+
+                     raise ValueError("Could not extract vocabulary data from pickle file")
+             except Exception as second_error:
+                 logger.error(f"Second loading method failed: {str(second_error)}")
+
+                 # Strategy 3: rewrite the pickle with fix_vocab_pickle as a last resort
+                 try:
+                     from app.fix_vocab_pickle import fix_vocab_pickle
+                     fixed_path = path + "_fixed.pkl"
+                     vocab = fix_vocab_pickle(path, fixed_path)
+                     if vocab:
+                         logger.info(f"Vocabulary fixed and saved to {fixed_path}")
+                         return vocab
+                 except Exception as fix_error:
+                     logger.error(f"Vocabulary fixing failed: {str(fix_error)}")
+
+                 raise RuntimeError(f"All vocabulary loading methods failed. Original error: {str(first_error)}")
+
+
+ # Encoder: pretrained ResNet
+ class EncoderCNN(torch.nn.Module):
+     def __init__(self, embed_dim):
+         super(EncoderCNN, self).__init__()
+         # Load a pretrained ResNet
+         import torchvision.models as models
+         resnet = models.resnet50(pretrained=True)
+         # Remove the final FC layer
+         modules = list(resnet.children())[:-1]
+         self.resnet = torch.nn.Sequential(*modules)
+         # Project to the embedding dimension
+         self.fc = torch.nn.Linear(resnet.fc.in_features, embed_dim)
+         self.bn = torch.nn.BatchNorm1d(embed_dim)
+         self.dropout = torch.nn.Dropout(0.5)
+
+     def forward(self, images):
+         with torch.no_grad():  # No gradients for the pretrained backbone
+             features = self.resnet(images)
+         features = features.reshape(features.size(0), -1)
+         features = self.fc(features)
+         features = self.bn(features)
+         features = self.dropout(features)
+         return features
+
+
+ # Positional encoding for the Transformer
+ class PositionalEncoding(torch.nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+
+         # Create the positional encoding table
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+
+         # Register as a buffer (not a model parameter)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         x = x + self.pe[:, :x.size(1), :].to(x.device)
+         return x
+
+
+ # Custom Transformer decoder
+ class TransformerDecoder(torch.nn.Module):
+     def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, dropout=0.1):
+         super(TransformerDecoder, self).__init__()
+
+         # Embedding layer
+         self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
+         self.positional_encoding = PositionalEncoding(embed_dim)
+
+         # Transformer decoder layers
+         decoder_layer = torch.nn.TransformerDecoderLayer(
+             d_model=embed_dim,
+             nhead=num_heads,
+             dim_feedforward=ff_dim,
+             dropout=dropout,
+             batch_first=True
+         )
+         self.transformer_decoder = torch.nn.TransformerDecoder(
+             decoder_layer,
+             num_layers=num_layers
+         )
+
+         # Output layer
+         self.fc = torch.nn.Linear(embed_dim, vocab_size)
+         self.dropout = torch.nn.Dropout(dropout)
+
+     def generate_square_subsequent_mask(self, sz):
+         # Create a mask to prevent attention to future tokens
+         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+         return mask
+
+     def forward(self, tgt, memory):
+         # Create the causal mask for the decoder
+         tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
+
+         # Embed tokens and add positional encoding
+         tgt = self.embedding(tgt) * math.sqrt(self.embedding.embedding_dim)
+         tgt = self.positional_encoding(tgt)
+         tgt = self.dropout(tgt)
+
+         # Pass through the transformer decoder
+         output = self.transformer_decoder(
+             tgt,
+             memory,
+             tgt_mask=tgt_mask
+         )
+
+         # Project to the vocabulary
+         output = self.fc(output)
+         return output
+
+
+ # Complete image captioning model
+ class ImageCaptioningModel(torch.nn.Module):
+     def __init__(self, vocab_size, embed_dim, hidden_dim, num_heads, num_layers):
+         super(ImageCaptioningModel, self).__init__()
+
+         # Image encoder
+         self.encoder = EncoderCNN(embed_dim)
+
+         # Caption decoder
+         self.decoder = TransformerDecoder(
+             vocab_size=vocab_size,
+             embed_dim=embed_dim,
+             num_heads=num_heads,
+             ff_dim=hidden_dim,
+             num_layers=num_layers
+         )
+
+     def forward(self, images, captions):
+         # Encode images
+         img_features = self.encoder(images)
+
+         # Reshape for the transformer: (batch_size, seq_len, embed_dim).
+         # Here seq_len=1, since a single "token" represents the whole image.
+         img_features = img_features.unsqueeze(1)
+
+         # Decode captions (excluding the last token, typically <EOS>)
+         outputs = self.decoder(captions[:, :-1], img_features)
+         return outputs
+
+     def generate_caption(self, image, vocab, max_length=20):
+         """Generate a caption for the given image"""
+         with torch.no_grad():
+             # Encode the image
+             img_features = self.encoder(image.unsqueeze(0))
+             img_features = img_features.unsqueeze(1)
+
+             # Start with the < SOS > token
+             current_ids = torch.tensor([[vocab.word2idx['< SOS >']]], dtype=torch.long).to(image.device)
+
+             # Generate words one by one (greedy decoding)
+             result_caption = []
+             for i in range(max_length):
+                 # Predict the next word
+                 outputs = self.decoder(current_ids, img_features)
+                 # Take the most likely next word
+                 _, predicted = outputs[:, -1, :].max(1)
+
+                 # Add the predicted word to the sequence
+                 result_caption.append(predicted.item())
+
+                 # Stop at <EOS>
+                 if predicted.item() == vocab.word2idx['<EOS>']:
+                     break
+
+                 # Extend the current sequence for the next iteration
+                 current_ids = torch.cat([current_ids, predicted.unsqueeze(0)], dim=1)
+
+             # Convert word indices to words
+             words = [vocab.idx2word[idx] for idx in result_caption]
+
+             # Drop the trailing <EOS> token if present
+             if words and words[-1] == '<EOS>':
+                 words = words[:-1]
+
+             return ' '.join(words)
+
+
+ def load_image(image_path, transform=None):
+     """Load and preprocess an image"""
+     image = Image.open(image_path).convert('RGB')
+
+     if transform is None:
+         transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+     return transform(image)
+
+
+ def generate_caption(
+     image_path,
+     model_path,
+     vocab_path,
+     max_length=20,
+     device=None
+ ):
+     """Generate a caption for an image"""
+     # Set the device
+     if device is None:
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     logger.info(f"Using device: {device}")
+
+     # Check that all input files exist
+     if not os.path.exists(image_path):
+         raise FileNotFoundError(f"Image not found at {image_path}")
+     if not os.path.exists(model_path):
+         raise FileNotFoundError(f"Model not found at {model_path}")
+     if not os.path.exists(vocab_path):
+         raise FileNotFoundError(f"Vocabulary not found at {vocab_path}")
+
+     # Load the vocabulary
+     logger.info(f"Loading vocabulary from {vocab_path}")
+     vocab = Vocabulary.load(vocab_path)
+     logger.info(f"Loaded vocabulary with {len(vocab)} words")
+
+     # Hyperparameters - these must match the values used during training
+     embed_dim = 512
+     hidden_dim = 2048
+     num_layers = 6
+     num_heads = 8
+
+     # Initialize the model
+     logger.info("Initializing model")
+     model = ImageCaptioningModel(
+         vocab_size=len(vocab),
+         embed_dim=embed_dim,
+         hidden_dim=hidden_dim,
+         num_heads=num_heads,
+         num_layers=num_layers
+     ).to(device)
+
+     # Load the model weights
+     logger.info(f"Loading model weights from {model_path}")
+     checkpoint = torch.load(model_path, map_location=device)
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+
+     # Load and preprocess the image
+     logger.info(f"Loading and processing image from {image_path}")
+     image = load_image(image_path).to(device)
+
+     # Generate the caption
+     logger.info("Generating caption")
+     caption = model.generate_caption(image, vocab, max_length=max_length)
+     logger.info(f"Generated caption: {caption}")
+
+     return caption
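For reference, this service can also be invoked directly, without going through the API layer. A minimal sketch, assuming the model files have been downloaded to `app/models` (the paths mirror those used by `app/api.py`) and `test.jpg` is a placeholder local image:

```python
from app.image_captioning_service import generate_caption

# test.jpg is a placeholder; any local image file works
caption = generate_caption(
    image_path="test.jpg",
    model_path="app/models/image_captioning_model.pth",
    vocab_path="app/models/vocab.pkl",
    max_length=20,
)
print(caption)
```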
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ fastapi==0.103.1
+ uvicorn==0.23.2
+ python-multipart==0.0.6
+ pillow==10.0.0
+ torch==2.0.1
+ torchvision==0.15.2
+ nltk==3.8.1
+ huggingface-hub==0.16.4
+ numpy==1.24.3
+ aiofiles==23.1.0
+ python-dotenv==1.0.0