Spaces:
Sleeping
Sleeping
| import os | |
| import base64 | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import shutil | |
| import uuid | |
| import logging | |
| from typing import Dict, List, Any | |
| import json | |
| # Import scene graph service | |
| from app.scene_graph_service import process_image | |
# --- Logging -----------------------------------------------------------------
# INFO-level root configuration, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Scratch storage ---------------------------------------------------------
# Per-request uploads and generated outputs live under /tmp, which is expected
# to be writable in the deployment environment.
UPLOAD_DIR = "/tmp/uploads"
OUTPUT_DIR = "/tmp/outputs"

# Create the scratch roots and the model directory at import time
# (no-op when they already exist).
for _startup_dir in (UPLOAD_DIR, OUTPUT_DIR, "app/models"):
    os.makedirs(_startup_dir, exist_ok=True)
del _startup_dir

# --- Application -------------------------------------------------------------
app = FastAPI(title="Scene Graph Generation API")

# CORS: every origin, method and header is accepted, and credentials are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; revisit if cookies or authentication are ever added.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
def read_root():
    """Return a short description of the service and how to use it.

    NOTE(review): no route decorator is visible for this function in this
    listing; presumably it is meant to be registered as ``GET /`` — confirm.
    """
    payload = dict(
        message="Scene Graph Generation API is running",
        usage="POST /generate with an image file to generate a scene graph",
        docs="Visit /docs for API documentation",
    )
    return payload
def _encode_file_base64(path: str) -> str:
    """Return the bytes of the file at *path* as a base64 text string."""
    with open(path, "rb") as fh:
        return base64.b64encode(fh.read()).decode("utf-8")


def _remove_job_dirs(job_dirs: List[str]) -> None:
    """Best-effort removal of the per-request scratch directories in *job_dirs*.

    Failures are logged as warnings and never raised, so a cleanup problem
    cannot mask the request's real outcome.
    """
    if not job_dirs:
        return
    all_removed = True
    for path in job_dirs:
        try:
            shutil.rmtree(path)
        except OSError as e:
            all_removed = False
            logger.warning(f"Error cleaning up temporary files: {str(e)}")
    if all_removed:
        logger.info("Cleaned up temporary directories")


async def _save_upload(image: UploadFile, image_path: str) -> None:
    """Read the whole upload and write it to *image_path*.

    Raises:
        HTTPException: 500 if reading or writing fails; 400 if the upload is empty.
    """
    try:
        contents = await image.read()
        with open(image_path, "wb") as buffer:
            buffer.write(contents)
    except Exception as e:
        # Any read/write failure is reported as a server-side save error.
        logger.error(f"Error saving file: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error saving uploaded file: {str(e)}") from e
    # Checked outside the try above so this 400 is not re-wrapped as a 500.
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    logger.info(f"Image saved to {image_path} ({len(contents)} bytes)")


async def generate_scene_graph(
    image: UploadFile = File(...),
    confidence_threshold: float = Form(0.5),
    use_fixed_boxes: bool = Form(False),
) -> Dict[str, Any]:
    """Run scene-graph generation on an uploaded image.

    The upload is saved into a fresh per-request directory under UPLOAD_DIR.
    ``process_image`` writes its outputs into a matching directory under
    OUTPUT_DIR. Both directories are removed before returning, on success
    and on failure.

    NOTE(review): no route decorator is visible in this listing; read_root's
    usage text suggests this is served as ``POST /generate``. Confirm the
    registration exists.

    Args:
        image: Uploaded file. Its content type must start with ``image/``;
            a missing content type is treated as ``image/jpeg``.
        confidence_threshold: Value in [0, 1], forwarded to ``process_image``.
        use_fixed_boxes: Flag forwarded to ``process_image``.

    Returns:
        A dict with ``objects`` and ``relationships`` as returned by
        ``process_image``, plus ``annotated_image`` and ``graph_image`` holding
        the base64-encoded contents of the two generated image files.

    Raises:
        HTTPException: 400 for invalid input (no file, non-image content type,
            threshold outside [0, 1], empty file). 500 for missing model files,
            save/read failures, missing outputs, or any unexpected error.
    """
    # Directories created for this request; removed in `finally`.
    job_dirs: List[str] = []
    try:
        if image is None:
            raise HTTPException(status_code=400, detail="No image file provided")
        logger.info(f"Received file: {image.filename}, content_type: {image.content_type}")

        # Validate a local copy. Recent Starlette versions expose
        # UploadFile.content_type as a read-only property, so assigning a
        # default to it would raise AttributeError.
        content_type = image.content_type
        if not content_type:
            logger.warning("No content type provided, assuming image/jpeg")
            content_type = "image/jpeg"
        if not content_type.startswith("image/"):
            raise HTTPException(
                status_code=400, detail=f"Uploaded file must be an image, got {content_type}"
            )
        # Written as `not (...)` so NaN (which fails every comparison) is rejected too.
        if not (0 <= confidence_threshold <= 1):
            raise HTTPException(
                status_code=400, detail="Confidence threshold must be between 0 and 1"
            )

        # Unique per-request directories; the first UUID segment names the files.
        job_id = str(uuid.uuid4())
        short_id = job_id.split("-")[0]
        upload_job_dir = os.path.join(UPLOAD_DIR, job_id)
        output_job_dir = os.path.join(OUTPUT_DIR, job_id)
        for job_dir in (upload_job_dir, output_job_dir):
            # The requested mode is still filtered by the process umask.
            os.makedirs(job_dir, exist_ok=True, mode=0o777)
            job_dirs.append(job_dir)
        logger.info(f"Created upload directory: {upload_job_dir}")
        logger.info(f"Created output directory: {output_job_dir}")

        # Keep the client's extension when it has one; otherwise fall back to .jpg.
        file_ext = os.path.splitext(image.filename or "")[1] or ".jpg"
        image_path = os.path.join(upload_job_dir, f"{short_id}{file_ext}")
        await _save_upload(image, image_path)

        model_path = "app/models/model.pth"
        vocabulary_path = "app/models/vocabulary.json"
        if not os.path.exists(model_path):
            logger.error(f"Model file not found: {model_path}")
            raise HTTPException(status_code=500, detail=f"Model file not found: {model_path}")
        if not os.path.exists(vocabulary_path):
            logger.error(f"Vocabulary file not found: {vocabulary_path}")
            raise HTTPException(status_code=500, detail=f"Vocabulary file not found: {vocabulary_path}")

        # NOTE(review): process_image is called synchronously inside an async
        # endpoint, so the event loop is blocked for its whole duration.
        # Consider a thread pool once process_image is confirmed thread-safe.
        objects, relationships, annotated_image_path, graph_path = process_image(
            image_path=image_path,
            model_path=model_path,
            vocabulary_path=vocabulary_path,
            confidence_threshold=confidence_threshold,
            use_fixed_boxes=use_fixed_boxes,
            output_dir=output_job_dir,
            base_filename=short_id,
        )
        logger.info(f"Processing complete. Annotated image: {annotated_image_path}, Graph: {graph_path}")

        if not os.path.exists(annotated_image_path):
            logger.error(f"Annotated image not generated: {annotated_image_path}")
            raise HTTPException(status_code=500, detail="Failed to generate annotated image")
        if not os.path.exists(graph_path):
            logger.error(f"Graph image not generated: {graph_path}")
            raise HTTPException(status_code=500, detail="Failed to generate graph image")

        # Encode while the files still exist; cleanup happens in `finally`.
        try:
            annotated_image_base64 = _encode_file_base64(annotated_image_path)
            graph_image_base64 = _encode_file_base64(graph_path)
        except OSError as e:
            logger.error(f"Error reading output images: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error reading output images: {str(e)}") from e
        logger.info("Successfully encoded images as base64")

        return {
            "objects": objects,
            "relationships": relationships,
            "annotated_image": annotated_image_base64,
            "graph_image": graph_image_base64,
        }
    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
    finally:
        _remove_job_dirs(job_dirs)
def health_check():
    """Report liveness with a constant ``{"status": "healthy"}`` payload.

    NOTE(review): no route decorator is visible for this function in this
    listing; confirm how it is registered.
    """
    status_payload = dict(status="healthy")
    return status_payload
if __name__ == "__main__":
    # Direct-execution entry point: serve `app` with uvicorn on every
    # interface at port 7860.
    import uvicorn as _uvicorn

    _serve_options = {"host": "0.0.0.0", "port": 7860}
    _uvicorn.run(app, **_serve_options)