# modal_app.py
import time
import uuid
from typing import Any, Dict

import modal

# Initialize the Modal app
app = modal.App("model-inference-app")

image = modal.Image.debian_slim().pip_install("torch", "fastapi", "uvicorn")

# Shared storage for job tracking. A plain module-level dict would NOT work
# here: the GPU worker and the web endpoint run in separate containers, so
# writes in one would never be visible in the other. modal.Dict is shared
# across containers; in production, consider Redis or a database instead.
job_storage = modal.Dict.from_name("model-inference-jobs", create_if_missing=True)


# GPU function for model inference
@app.function(
    image=image,
    gpu="A10G",
    timeout=300,  # 5-minute timeout
    retries=1,
)
def inference_worker(input_data: Dict[str, Any], job_id: str) -> Dict[str, Any]:
    """GPU-based inference worker."""
    # Imported inside the function so the module still imports locally,
    # where torch may not be installed (it is only in the Modal image).
    import torch  # placeholder: your real model code would use this

    # Simulate model loading (replace with your actual model)
    print(f"Loading model for job {job_id}")

    # Simulate some computation time
    time.sleep(2)

    # Your actual model inference logic goes here.
    # For this example, we simulate a simple operation.
    result = {
        "input": input_data,
        "processed": True,
        "timestamp": time.time(),
        "job_id": job_id,
        "result": [x * 2 for x in input_data.get("values", [])],  # example processing
    }

    # Store the result where the web endpoint's container can see it
    job_storage[job_id] = {
        "status": "completed",
        "result": result,
        "completed_at": time.time(),
    }

    return result


# FastAPI app for the REST endpoints
@app.function(image=image)
@modal.asgi_app()
def api():
    # Imported here, like torch above, so the module imports locally even
    # when fastapi is only installed in the image.
    from fastapi import FastAPI, HTTPException

    fastapi_app = FastAPI()

    @fastapi_app.post("/inference")
    async def start_inference(input_data: Dict[str, Any]):
        """Trigger an inference job and return its job ID."""
        job_id = str(uuid.uuid4())

        # Record the job as in progress
        job_storage[job_id] = {
            "status": "processing",
            "input": input_data,
            "started_at": time.time(),
        }

        try:
            # Fire and forget: the worker runs asynchronously on a GPU
            # container. Note that if the spawned worker itself crashes
            # (after its retries), the job will remain "processing".
            inference_worker.spawn(input_data, job_id)
            return {
                "job_id": job_id,
                "status": "processing",
                "message": "Inference started",
            }
        except Exception as e:
            job_storage[job_id] = {
                "status": "error",
                "error": str(e),
                "completed_at": time.time(),
            }
            raise HTTPException(status_code=500, detail=f"Inference failed: {e}")

    @fastapi_app.get("/inference/{job_id}")
    async def get_inference_status(job_id: str):
        """Get the status and result of an inference job."""
        if job_id not in job_storage:
            raise HTTPException(status_code=404, detail="Job not found")

        job_info = job_storage[job_id]
        status = job_info["status"]

        if status == "completed":
            return {"job_id": job_id, "status": "completed", "result": job_info["result"]}
        elif status == "processing":
            return {"job_id": job_id, "status": "processing", "message": "Inference is still running"}
        else:
            return {"job_id": job_id, "status": "error", "error": job_info["error"]}

    @fastapi_app.get("/health")
    async def health_check():
        return {"status": "healthy"}

    return fastapi_app


# Local test entrypoint: run with `modal run modal_app.py`.
# (An `if __name__ == "__main__"` guard would never fire here, because
# `modal run` imports this module rather than executing it as a script.)
@app.local_entrypoint()
def main():
    # Test the inference function directly
    result = inference_worker.remote({"values": [1, 2, 3, 4]}, "test-job")
    print("Direct inference result:", result)

    # The HTTP endpoints are exercised over HTTP once deployed
    print("API endpoints available:")
    print("POST /inference - Start inference job")
    print("GET /inference/{job_id} - Check job status")
    print("GET /health - Health check")
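
# ---------------------------------------------------------------------------
# Example client (a minimal sketch, not part of the app itself).
# After `modal deploy modal_app.py`, Modal prints the public URL for the
# `api` web endpoint; the BASE_URL below is a hypothetical placeholder and
# must be replaced with your deployment's actual URL.
#
#   import time
#   import requests
#
#   BASE_URL = "https://your-workspace--model-inference-app-api.modal.run"  # hypothetical
#
#   job = requests.post(f"{BASE_URL}/inference", json={"values": [1, 2, 3, 4]}).json()
#   job_id = job["job_id"]
#
#   # Poll until the job leaves the "processing" state
#   while True:
#       status = requests.get(f"{BASE_URL}/inference/{job_id}").json()
#       if status["status"] != "processing":
#           break
#       time.sleep(1)
#
#   print(status)
# ---------------------------------------------------------------------------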