Spaces:
Sleeping
Sleeping
| import os | |
| import base64 | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import shutil | |
| import uuid | |
| import logging | |
| from typing import Dict, Any | |
| import torch | |
| # Import image captioning service | |
| from app.image_captioning_service import generate_caption, Vocabulary, ImageCaptioningModel, EncoderCNN, TransformerDecoder, PositionalEncoding | |
# Expose the captioning classes as attributes of the __main__ module so that
# unpickling can resolve them there (presumably the saved model/vocabulary
# pickles reference these classes under __main__ -- confirm against how the
# artifacts were saved).
import __main__

__main__.Vocabulary = Vocabulary
__main__.ImageCaptioningModel = ImageCaptioningModel
__main__.EncoderCNN = EncoderCNN
__main__.TransformerDecoder = TransformerDecoder
__main__.PositionalEncoding = PositionalEncoding
# Logging: INFO level on the root logger, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Uploads are staged under /tmp, which should be writable.
UPLOAD_DIR = "/tmp/uploads"

# Make sure the upload staging area and the model directory exist.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs("app/models", exist_ok=True)

# FastAPI application object.
app = FastAPI(title="Image Captioning API")

# CORS: every origin, method and header is allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# very permissive -- confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Run on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is registered with `app` (e.g. as the GET "/" handler).
def read_root() -> Dict[str, str]:
    """Return a short status message with usage and documentation hints."""
    return {
        "message": "Image Captioning API is running",
        "usage": "POST /generate with an image file to generate a caption",
        "docs": "Visit /docs for API documentation",
    }
# NOTE(review): no route decorator is visible on this function in this view;
# the root endpoint's usage text advertises "POST /generate" -- confirm this
# handler is registered with `app` under that route.
async def generate_image_caption(
    image: UploadFile = File(...),
    max_length: int = Form(20),
) -> Dict[str, Any]:
    """Generate a caption for an uploaded image.

    The upload is written to a per-request directory under ``UPLOAD_DIR``,
    handed to ``generate_caption`` together with the model and vocabulary
    paths, and echoed back base64-encoded alongside the caption. The
    per-request directory is removed on every exit path.

    Args:
        image: The uploaded file. Its content type must start with
            ``image/``; a missing content type is treated as ``image/jpeg``.
        max_length: Maximum caption length, between 1 and 100 inclusive.

    Returns:
        A dict with ``"caption"`` (the generated caption) and ``"image"``
        (the uploaded bytes, base64-encoded as a UTF-8 string).

    Raises:
        HTTPException: 400 for a missing, non-image or empty upload or an
            out-of-range ``max_length``; 500 when the upload cannot be
            saved, a model file is missing, or caption generation fails.
    """
    # Set only once the per-request directory exists, so the ``finally``
    # block never tries to remove something that was not created.
    cleanup_dir = None
    try:
        # Validate before touching any attribute of the upload.
        if image is None:
            raise HTTPException(status_code=400, detail="No image file provided")

        logger.info(f"Received file: {image.filename}, content_type: {image.content_type}")

        # Work on a local copy: the upload's content_type is not reliably
        # assignable, so the default must not be written back to the object.
        content_type = image.content_type
        if not content_type:
            logger.warning("No content type provided, assuming image/jpeg")
            content_type = "image/jpeg"
        if not content_type.startswith("image/"):
            raise HTTPException(
                status_code=400,
                detail=f"Uploaded file must be an image, got {content_type}",
            )
        if not (0 < max_length <= 100):
            raise HTTPException(
                status_code=400,
                detail="Maximum caption length must be between 1 and 100",
            )

        # Unique per-request working directory under UPLOAD_DIR.
        job_id = str(uuid.uuid4())
        short_id = job_id.split("-")[0]
        upload_job_dir = os.path.join(UPLOAD_DIR, job_id)
        os.makedirs(upload_job_dir, exist_ok=True, mode=0o777)
        cleanup_dir = upload_job_dir
        logger.info(f"Created upload directory: {upload_job_dir}")

        # Keep the client's extension when there is one; default to .jpg.
        file_ext = os.path.splitext(image.filename)[1] if image.filename else ""
        image_path = os.path.join(upload_job_dir, f"{short_id}{file_ext or '.jpg'}")

        # Persist the upload so generate_caption can read it from disk.
        try:
            contents = await image.read()
            with open(image_path, "wb") as buffer:
                buffer.write(contents)
        except Exception as e:
            logger.error(f"Error saving file: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Error saving uploaded file: {str(e)}"
            ) from e

        # Checked outside the try above so this 400 is not rewritten as a 500.
        if not contents:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")
        logger.info(f"Image saved to {image_path} ({len(contents)} bytes)")

        # Model artifacts expected on disk.
        model_path = "app/models/image_captioning_model.pth"
        vocabulary_path = "app/models/vocab.pkl"
        if not os.path.exists(model_path):
            logger.error(f"Model file not found: {model_path}")
            raise HTTPException(status_code=500, detail=f"Model file not found: {model_path}")
        if not os.path.exists(vocabulary_path):
            logger.error(f"Vocabulary file not found: {vocabulary_path}")
            raise HTTPException(
                status_code=500, detail=f"Vocabulary file not found: {vocabulary_path}"
            )

        # NOTE(review): generate_caption is called synchronously inside this
        # coroutine; if it is slow, consider running it in a worker thread.
        try:
            caption = generate_caption(
                image_path=image_path,
                model_path=model_path,
                vocab_path=vocabulary_path,
                max_length=max_length,
                device=device,
            )
            logger.info(f"Generated caption: {caption}")
        except Exception as e:
            logger.error(f"Error generating caption: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Error generating caption: {str(e)}"
            ) from e

        # The uploaded bytes are still in memory; no need to re-read the file.
        image_base64 = base64.b64encode(contents).decode("utf-8")
        logger.info("Successfully encoded image as base64")

        return {
            "caption": caption,
            "image": image_base64,
        }
    except HTTPException:
        # Deliberate HTTP errors (including the 400s above) pass through
        # unchanged instead of being rewrapped as a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Error processing image: {str(e)}"
        ) from e
    finally:
        # Best-effort cleanup on success and failure alike.
        if cleanup_dir is not None:
            try:
                shutil.rmtree(cleanup_dir)
                logger.info("Cleaned up temporary directories")
            except Exception as e:
                logger.warning(f"Error cleaning up temporary files: {str(e)}")
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is registered with `app` as the health endpoint.
def health_check() -> Dict[str, str]:
    """Return a static payload indicating the service is up."""
    return {"status": "healthy"}