"""FastAPI application exposing an image-captioning endpoint."""

import base64
import logging
import os
import shutil
import uuid
from typing import Any, Dict, Optional

import __main__

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware

# Import image captioning service
from app.image_captioning_service import (
    EncoderCNN,
    ImageCaptioningModel,
    PositionalEncoding,
    TransformerDecoder,
    Vocabulary,
    generate_caption,
)

# Register these classes in the main module to help with unpickling
setattr(__main__, 'Vocabulary', Vocabulary)
setattr(__main__, 'ImageCaptioningModel', ImageCaptioningModel)
setattr(__main__, 'EncoderCNN', EncoderCNN)
setattr(__main__, 'TransformerDecoder', TransformerDecoder)
setattr(__main__, 'PositionalEncoding', PositionalEncoding)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use /tmp directory which should be writable
UPLOAD_DIR = "/tmp/uploads"

# Locations of the serialized model and vocabulary used for captioning
MODEL_PATH = "app/models/image_captioning_model.pth"
VOCABULARY_PATH = "app/models/vocab.pkl"

# Content type assumed when the client does not send one
DEFAULT_CONTENT_TYPE = "image/jpeg"

# Create necessary directories
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs("app/models", exist_ok=True)

# Initialize FastAPI app
app = FastAPI(title="Image Captioning API")

# Add CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; consider restricting allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Get device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info("Using device: %s", device)


@app.get("/")
def read_root():
    """Return a short description of the API and how to use it."""
    return {
        "message": "Image Captioning API is running",
        "usage": "POST /generate with an image file to generate a caption",
        "docs": "Visit /docs for API documentation",
    }


def _validate_request(content_type: str, max_length: int) -> None:
    """Raise a 400 ``HTTPException`` if the upload metadata is unacceptable."""
    if not content_type.startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail=f"Uploaded file must be an image, got {content_type}",
        )
    if not (0 < max_length <= 100):
        raise HTTPException(
            status_code=400,
            detail="Maximum caption length must be between 1 and 100",
        )


def _upload_extension(filename: Optional[str]) -> str:
    """Return the extension of *filename*, falling back to ``.jpg``."""
    extension = os.path.splitext(filename)[1] if filename else ""
    return extension or ".jpg"


def _save_upload(contents: bytes, directory: str, filename: str) -> str:
    """Write *contents* to ``directory/filename`` and return the full path.

    Raises:
        HTTPException: 500 if the directory or file cannot be written.
    """
    path = os.path.join(directory, filename)
    try:
        os.makedirs(directory, exist_ok=True, mode=0o777)
        logger.info("Created upload directory: %s", directory)
        with open(path, "wb") as buffer:
            buffer.write(contents)
    except OSError as e:
        logger.error("Error saving file: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Error saving uploaded file: {str(e)}",
        ) from e
    logger.info("Image saved to %s (%d bytes)", path, len(contents))
    return path


def _run_captioning(image_path: str, max_length: int) -> Any:
    """Check that the model files exist and call ``generate_caption``.

    Raises:
        HTTPException: 500 if a model file is missing or captioning fails.
    """
    if not os.path.exists(MODEL_PATH):
        logger.error("Model file not found: %s", MODEL_PATH)
        raise HTTPException(
            status_code=500, detail=f"Model file not found: {MODEL_PATH}"
        )
    if not os.path.exists(VOCABULARY_PATH):
        logger.error("Vocabulary file not found: %s", VOCABULARY_PATH)
        raise HTTPException(
            status_code=500,
            detail=f"Vocabulary file not found: {VOCABULARY_PATH}",
        )

    # NOTE(review): this runs synchronously inside an async handler, so the
    # event loop is blocked for the duration of the call; consider offloading
    # it to a worker thread if requests need to be served concurrently.
    try:
        caption = generate_caption(
            image_path=image_path,
            model_path=MODEL_PATH,
            vocab_path=VOCABULARY_PATH,
            max_length=max_length,
            device=device,
        )
    except Exception as e:
        logger.error("Error generating caption: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Error generating caption: {str(e)}"
        ) from e
    logger.info("Generated caption: %s", caption)
    return caption


def _remove_job_dir(directory: str) -> None:
    """Best-effort removal of a per-request upload directory."""
    if not os.path.isdir(directory):
        return
    try:
        shutil.rmtree(directory)
        logger.info("Cleaned up temporary directories")
    except Exception as e:
        # Cleanup must never mask the real response or error.
        logger.warning("Error cleaning up temporary files: %s", e)


@app.post("/generate")
async def generate_image_caption(
    image: UploadFile = File(...),
    max_length: int = Form(20),
) -> Dict[str, Any]:
    """Generate a caption for an uploaded image.

    Args:
        image: The uploaded file; its content type must start with ``image/``
            (a missing content type is treated as ``image/jpeg``).
        max_length: Maximum caption length, between 1 and 100 inclusive.

    Returns:
        A dict with the generated ``caption`` and the uploaded ``image``
        encoded as a base64 string.

    Raises:
        HTTPException: 400 for invalid or empty uploads, 500 for storage,
            model, or unexpected processing failures.
    """
    if image is None:
        raise HTTPException(status_code=400, detail="No image file provided")

    # Debug information
    logger.info(
        "Received file: %s, content_type: %s", image.filename, image.content_type
    )

    # Use a local value instead of assigning to the UploadFile, whose
    # content_type is not guaranteed to be writable.
    content_type = image.content_type
    if not content_type:
        logger.warning("No content type provided, assuming image/jpeg")
        content_type = DEFAULT_CONTENT_TYPE
    _validate_request(content_type, max_length)

    # Generate unique ID for this job
    job_id = str(uuid.uuid4())
    short_id = job_id.split("-")[0]
    upload_job_dir = os.path.join(UPLOAD_DIR, job_id)

    try:
        contents = await image.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        image_filename = f"{short_id}{_upload_extension(image.filename)}"
        image_path = _save_upload(contents, upload_job_dir, image_filename)

        caption = _run_captioning(image_path, max_length)

        # The uploaded bytes are still in memory, so encode them directly
        # rather than reading the saved file back from disk.
        image_base64 = base64.b64encode(contents).decode("utf-8")
        logger.info("Successfully encoded image as base64")

        return {"caption": caption, "image": image_base64}
    except HTTPException:
        # Preserve deliberate status codes (e.g. 400) instead of wrapping
        # them in a generic 500.
        raise
    except Exception as e:
        logger.error("Error processing image: %s", e, exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Error processing image: {str(e)}"
        ) from e
    finally:
        # Clean up on success and failure so failed requests do not leak files.
        _remove_job_dir(upload_job_dir)


@app.get("/health")
def health_check():
    """Return a static payload indicating the service is up."""
    return {"status": "healthy"}