|
|
|
|
|
import modal |
|
|
import time |
|
|
import uuid |
|
|
from typing import Dict, Any |
|
|
from fastapi import FastAPI, HTTPException |
|
|
import asyncio |
|
|
|
|
|
|
|
|
# Modal application handle; the functions below register themselves on it.
app = modal.App("model-inference-app")

# Container image used by both the GPU worker and the ASGI web endpoint.
image = modal.Image.debian_slim().pip_install("torch", "fastapi", "uvicorn")

# Shared job store. A plain module-level dict is NOT shared across Modal
# containers: the GPU worker and the web endpoint run in separate containers,
# so status updates written by the worker would never be visible to the API.
# modal.Dict is a distributed key-value store visible to every container in
# the app, and supports the same `store[key] = ...` / `key in store` /
# `store[key]` operations the rest of this file already uses.
job_storage: modal.Dict = modal.Dict.from_name(
    "model-inference-jobs", create_if_missing=True
)
|
|
|
|
|
|
|
|
@app.function(
    image=image,
    gpu="A10G",
    timeout=300,
    retries=1
)
def inference_worker(input_data: Dict[str, Any], job_id: str) -> Dict[str, Any]:
    """GPU-based inference worker.

    Runs remotely on an A10G container, computes a placeholder "inference"
    result, publishes it to ``job_storage`` under ``job_id``, and returns it.

    Args:
        input_data: Request payload; the optional ``"values"`` key (a list of
            numbers) is treated as the model input.
        job_id: Identifier used to publish the result into ``job_storage``.

    Returns:
        A result dict echoing the input plus a completion timestamp and the
        processed ``"values"`` (each element doubled — placeholder for real
        model inference).
    """
    # torch is imported inside the function because it is only installed in
    # the remote image, not necessarily on the machine that imports this file.
    import torch

    print(f"Loading model for job {job_id}")

    # Simulate model load / inference latency.
    time.sleep(2)

    result = {
        "input": input_data,
        "processed": True,
        "timestamp": time.time(),
        "job_id": job_id,
        # Placeholder "model": double each input value; yields [] when the
        # "values" key is absent.
        "result": [x * 2 for x in input_data.get("values", [])],
    }

    # Publish completion so the API's status endpoint can see it.
    # NOTE(review): this write only reaches the API container if job_storage
    # is backed by shared storage (e.g. modal.Dict); a plain module-level
    # dict is per-container — verify the definition at the top of the file.
    job_storage[job_id] = {
        "status": "completed",
        "result": result,
        "completed_at": time.time(),
    }

    return result
|
|
|
|
|
|
|
|
@app.function(image=image)
@modal.asgi_app()
def api():
    """Modal ASGI app exposing the inference HTTP API.

    Routes:
        POST /inference          -- spawn an async inference job, return its ID
        GET  /inference/{job_id} -- poll job status / fetch the result
        GET  /health             -- liveness probe
    """
    fastapi_app = FastAPI()

    @fastapi_app.post("/inference")
    async def start_inference(input_data: Dict[str, Any]):
        """Trigger an inference job and return its job ID."""
        job_id = str(uuid.uuid4())

        # Record the job *before* spawning so a fast poll cannot observe a 404.
        job_storage[job_id] = {
            "status": "processing",
            "input": input_data,
            "started_at": time.time(),
        }

        try:
            # Fire-and-forget: the worker writes its result into job_storage
            # when it finishes; clients poll GET /inference/{job_id}.
            inference_worker.spawn(input_data, job_id)
            return {
                "job_id": job_id,
                "status": "processing",
                "message": "Inference started",
            }
        except Exception as e:
            # Mark the job failed so pollers don't wait forever on a job that
            # never actually started.
            job_storage[job_id] = {
                "status": "error",
                "error": str(e),
                "completed_at": time.time(),
            }
            raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")

    @fastapi_app.get("/inference/{job_id}")
    async def get_inference_status(job_id: str):
        """Get the status and (when finished) the result of an inference job."""
        if job_id not in job_storage:
            raise HTTPException(status_code=404, detail="Job not found")

        job_info = job_storage[job_id]
        status = job_info["status"]

        if status == "completed":
            return {
                "job_id": job_id,
                "status": "completed",
                "result": job_info["result"],
            }
        if status == "processing":
            return {
                "job_id": job_id,
                "status": "processing",
                "message": "Inference is still running",
            }
        # Fallback covers "error" and any unexpected status value. Use .get()
        # so a record without an "error" key yields a clean response instead
        # of a KeyError (the original indexed job_info["error"] unconditionally).
        return {
            "job_id": job_id,
            "status": "error",
            "error": job_info.get("error", "unknown error"),
        }

    @fastapi_app.get("/health")
    async def health_check():
        """Liveness probe."""
        return {"status": "healthy"}

    return fastapi_app
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Modal apps are not launched by executing this file directly; use the
    # Modal CLI instead (`modal run <file>` or `modal serve <file>`). The
    # original guard imported uvicorn and then did nothing with it, so this
    # is intentionally a no-op.
    pass
|
|
|
|
|
|
|
|
@app.local_entrypoint()
def main():
    """Smoke-test entrypoint: run one inference remotely, then list the API routes."""
    sample_payload = {"values": [1, 2, 3, 4]}
    outcome = inference_worker.remote(sample_payload, "test-job")
    print("Direct inference result:", outcome)

    route_lines = [
        "API endpoints available:",
        "POST /inference - Start inference job",
        "GET /inference/{job_id} - Check job status",
        "GET /health - Health check",
    ]
    for line in route_lines:
        print(line)
|
|
|