# modal_app.py
import modal
import time
import uuid
from typing import Dict, Any
from fastapi import FastAPI, HTTPException

# Initialize Modal app
app = modal.App("model-inference-app")
image = modal.Image.debian_slim().pip_install("torch", "fastapi", "uvicorn")

# Shared job store. A plain module-level dict would not work here: the API
# endpoints and the GPU worker run in separate Modal containers, each with its
# own memory. modal.Dict persists across containers (in production, you would
# still likely use Redis or a database).
job_storage = modal.Dict.from_name("inference-job-storage", create_if_missing=True)

# GPU app for model inference
@app.function(
    image=image,
    gpu="A10G",
    timeout=300,  # 5-minute timeout
    retries=1
)
def inference_worker(input_data: Dict[str, Any], job_id: str) -> Dict[str, Any]:
    """
    GPU-based inference worker.
    """
    import torch  # imported here because torch exists only in the container image

    try:
        # Simulate model loading (replace with your actual model)
        print(f"Loading model for job {job_id}")

        # Simulate some computation time
        time.sleep(2)

        # Your actual model inference logic goes here; this example just
        # doubles each input value.
        result = {
            "input": input_data,
            "processed": True,
            "timestamp": time.time(),
            "job_id": job_id,
            "result": [x * 2 for x in input_data.get("values", [])],
        }

        # Store the result where the API container can read it
        job_storage[job_id] = {
            "status": "completed",
            "result": result,
            "completed_at": time.time(),
        }
        return result
    except Exception as e:
        # Record the failure so the job does not stay "processing" forever
        job_storage[job_id] = {
            "status": "error",
            "error": str(e),
            "completed_at": time.time(),
        }
        raise
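
# A minimal sketch of what real inference could look like in place of the
# simulation above, assuming a TorchScript checkpoint baked into the image at
# /model.pt (the path and model format are illustrative, not part of this app):
#
#     model = torch.jit.load("/model.pt").to("cuda").eval()
#     with torch.no_grad():
#         batch = torch.tensor(input_data["values"], dtype=torch.float32, device="cuda")
#         output = model(batch).tolist()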

# FastAPI app for REST endpoints
@app.function(image=image)
@modal.asgi_app()
def api():
    fastapi_app = FastAPI()
    
    @fastapi_app.post("/inference")
    async def start_inference(input_data: Dict[str, Any]):
        """
        Trigger inference job and return job ID
        """
        job_id = str(uuid.uuid4())
        
        # Store job in progress
        job_storage[job_id] = {
            "status": "processing",
            "input": input_data,
            "started_at": time.time()
        }
        
        # Start inference in the background
        try:
            # spawn() launches the GPU function without blocking and returns a
            # FunctionCall handle; here we track the job via job_id instead
            inference_worker.spawn(input_data, job_id)
            
            return {
                "job_id": job_id,
                "status": "processing",
                "message": "Inference started"
            }
        except Exception as e:
            job_storage[job_id] = {
                "status": "error",
                "error": str(e),
                "completed_at": time.time()
            }
            raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
    
    @fastapi_app.get("/inference/{job_id}")
    async def get_inference_status(job_id: str):
        """
        Get status and results of inference job
        """
        if job_id not in job_storage:
            raise HTTPException(status_code=404, detail="Job not found")
        
        job_info = job_storage[job_id]
        status = job_info["status"]
        
        if status == "completed":
            return {
                "job_id": job_id,
                "status": "completed",
                "result": job_info["result"]
            }
        elif status == "processing":
            return {
                "job_id": job_id,
                "status": "processing",
                "message": "Inference is still running"
            }
        else:
            return {
                "job_id": job_id,
                "status": "error",
                "error": job_info.get("error", "Unknown error")
            }
    
    @fastapi_app.get("/health")
    async def health_check():
        return {"status": "healthy"}
    
    return fastapi_app

# Local test entrypoint: run with `modal run modal_app.py`.
# (An `if __name__ == "__main__":` guard would hide this from the Modal CLI,
# which imports this module rather than executing it as a script.)
@app.local_entrypoint()
def main():
    # Invoke the GPU function directly on Modal
    result = inference_worker.remote({"values": [1, 2, 3, 4]}, "test-job")
    print("Direct inference result:", result)

    # The REST API is served separately, e.g. via `modal serve modal_app.py`
    print("API endpoints available:")
    print("POST /inference - Start inference job")
    print("GET /inference/{job_id} - Check job status")
    print("GET /health - Health check")