# fastapi_modal.py
import modal
import time
import uuid
from typing import Dict, Any

from fastapi import FastAPI, HTTPException

# Initialize Modal app
app = modal.App("model-inference-app")
image = modal.Image.debian_slim().pip_install("torch", "fastapi", "uvicorn")

# Shared job store. A module-level dict would not work here: the GPU worker and
# the API server run in separate containers, so writes from one would never be
# visible to the other. modal.Dict persists across containers (in production,
# use Redis or a database).
job_storage = modal.Dict.from_name("inference-job-storage", create_if_missing=True)

# GPU function for model inference
@app.function(
    image=image,
    gpu="A10G",
    timeout=300,  # 5-minute timeout
    retries=1,
)
def inference_worker(input_data: Dict[str, Any], job_id: str) -> Dict[str, Any]:
    """GPU-based inference worker; stores its result in the shared job store."""
    import torch  # imported here: torch is only installed in the container image

    # Simulate model loading (replace with your actual model)
    print(f"Loading model for job {job_id}")

    # Simulate some computation time
    time.sleep(2)

    # Your actual model inference logic goes here; this example simulates a
    # simple operation.
    result = {
        "input": input_data,
        "processed": True,
        "timestamp": time.time(),
        "job_id": job_id,
        "result": [x * 2 for x in input_data.get("values", [])],  # example processing
    }

    # Store the result where the API container can read it
    job_storage[job_id] = {
        "status": "completed",
        "result": result,
        "completed_at": time.time(),
    }
    return result
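

# A minimal sketch of what real inference might look like in place of the
# simulation above. Everything in it is illustrative: the tiny linear model and
# the helper name are hypothetical stand-ins for your own model loading and
# forward pass.
def _example_torch_inference(values):
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Hypothetical model: an untrained 1-in/1-out linear layer standing in for
    # a real checkpoint loaded from disk.
    model = torch.nn.Linear(1, 1).to(device)
    with torch.no_grad():
        x = torch.tensor(values, dtype=torch.float32, device=device).unsqueeze(1)
        y = model(x)
    return y.squeeze(1).cpu().tolist()
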
# FastAPI app for REST endpoints
@app.function(image=image)
@modal.asgi_app()
def api():
    fastapi_app = FastAPI()

    @fastapi_app.post("/inference")
    async def start_inference(input_data: Dict[str, Any]):
        """Trigger an inference job and return its job ID."""
        job_id = str(uuid.uuid4())

        # Record the job as in progress
        job_storage[job_id] = {
            "status": "processing",
            "input": input_data,
            "started_at": time.time(),
        }

        # Start inference in the background; it runs asynchronously on the GPU worker
        try:
            inference_worker.spawn(input_data, job_id)
            return {
                "job_id": job_id,
                "status": "processing",
                "message": "Inference started",
            }
        except Exception as e:
            job_storage[job_id] = {
                "status": "error",
                "error": str(e),
                "completed_at": time.time(),
            }
            raise HTTPException(status_code=500, detail=f"Inference failed: {e}")

    @fastapi_app.get("/inference/{job_id}")
    async def get_inference_status(job_id: str):
        """Get the status and, once finished, the result of an inference job."""
        if job_id not in job_storage:
            raise HTTPException(status_code=404, detail="Job not found")

        job_info = job_storage[job_id]
        status = job_info["status"]

        if status == "completed":
            return {
                "job_id": job_id,
                "status": "completed",
                "result": job_info["result"],
            }
        elif status == "processing":
            return {
                "job_id": job_id,
                "status": "processing",
                "message": "Inference is still running",
            }
        else:
            return {
                "job_id": job_id,
                "status": "error",
                "error": job_info["error"],
            }

    @fastapi_app.get("/health")
    async def health_check():
        return {"status": "healthy"}

    return fastapi_app
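

# To exercise the endpoints over HTTP, run a hot-reloading dev server with
# `modal serve fastapi_modal.py`, or deploy with `modal deploy fastapi_modal.py`;
# either way, Modal prints the public URL of the `api` web endpoint.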


# Local test entrypoint: run with `modal run fastapi_modal.py`. It must live at
# module level because `modal run` imports this file as a module, so an
# entrypoint defined under `if __name__ == "__main__":` would never be
# registered.
@app.local_entrypoint()
def main():
    # Test the inference function directly
    result = inference_worker.remote({"values": [1, 2, 3, 4]}, "test-job")
    print("Direct inference result:", result)
    # The HTTP endpoints are exercised against the running app (see the client
    # sketch below):
    print("API endpoints available:")
    print("POST /inference - Start inference job")
    print("GET /inference/{job_id} - Check job status")
    print("GET /health - Health check")