File size: 8,800 Bytes
91d209c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
"""
Replicate API endpoints
Handles video generation via Replicate's Python SDK

Based on standalone_video_creator.py flow:
- Runs the blocking replicate.run() call on a worker thread; clients poll /replicate/status/{id} for the result
- Sends the prompt text unchanged (no JSON stringification is done here)
- Supports image input for frame continuity
"""

from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional, Dict, Any
import os
import asyncio
import uuid
import json
from concurrent.futures import ThreadPoolExecutor

router = APIRouter()

# The replicate SDK is optional at import time: the generate endpoint refuses
# to run and the health endpoint reports it when this flag is False.
REPLICATE_AVAILABLE = False
try:
    import replicate
except ImportError:
    print("⚠️  Replicate package not installed. Run: pip install replicate")
else:
    REPLICATE_AVAILABLE = True

# replicate.run() blocks until the model finishes, so calls are handed to this
# small worker pool instead of running on the event loop.
executor = ThreadPoolExecutor(max_workers=4)

# Job status keyed by our own prediction id. Process-local (not shared between
# server workers) and not expired automatically; use a shared store such as
# Redis in production.
predictions: Dict[str, Dict[str, Any]] = {}


# Request/Response Models
class ReplicateGenerateRequest(BaseModel):
    # Body of POST /replicate/generate. Field names are camelCase;
    # generate_video maps them onto the model's snake_case input names.
    prompt: str  # forwarded unchanged as the model's "prompt" input
    imageUrl: Optional[str] = None  # sent as "image" when non-empty
    model: Optional[str] = "google/veo-3"  # Replicate model ref; None/"" falls back to google/veo-3
    aspectRatio: Optional[str] = "9:16"  # sent as "aspect_ratio" when non-empty
    seed: Optional[int] = None  # sent as "seed" when not None (so 0 is kept)


class ReplicateGenerateResponse(BaseModel):
    # Returned by POST /replicate/generate: the locally generated prediction
    # id (pass it to /replicate/status/{id}) and its initial status.
    id: str
    status: str


class ReplicateStatusResponse(BaseModel):
    # Returned by GET /replicate/status/{id}; built from one entry of the
    # in-memory `predictions` store.
    status: str  # this module writes "processing", "succeeded" or "failed"
    output: Optional[str] = None  # set to the same value as `url` on success
    url: Optional[str] = None  # video URL taken from the model output
    error: Optional[str] = None  # error message when status is "failed"


def get_replicate_api_key():
    """Return the Replicate API token from the REPLICATE_API_TOKEN env var.

    Raises:
        HTTPException: status 500 when the variable is unset or empty.
    """
    token = os.getenv('REPLICATE_API_TOKEN')
    if token:
        return token
    raise HTTPException(
        status_code=500,
        detail="REPLICATE_API_TOKEN not configured. Add REPLICATE_API_TOKEN to .env.local"
    )


def _extract_video_url(output: Any) -> Optional[str]:
    """Return the video URL contained in a ``replicate.run()`` result.

    Shapes are checked in this order:

    * ``None`` -> ``None`` (rather than the literal string "None").
    * ``str`` -> returned unchanged.
    * object with a ``url`` attribute -> that attribute (a property, not a
      method).
    * ``dict`` -> the first value that is a string or has a string ``url``;
      keys are never used as the URL.
    * any other iterable (list, generator) -> the first item that is a string
      or has a string ``url`` attribute, so lists of file-output objects work.
    * anything else -> ``str(output)``.

    Returns ``None`` when a container holds nothing usable.
    """
    if output is None:
        return None
    if isinstance(output, str):
        return output
    if hasattr(output, 'url'):
        return output.url
    if isinstance(output, dict):
        # Iterating a dict directly would yield its key names, not URLs.
        candidates = output.values()
    elif hasattr(output, '__iter__'):
        candidates = output
    else:
        return str(output)
    for item in candidates:
        if isinstance(item, str):
            return item
        item_url = getattr(item, 'url', None)
        if isinstance(item_url, str):
            return item_url
    return None


def run_replicate_sync(
    prediction_id: str,
    model: str,
    input_data: Dict[str, Any]
) -> None:
    """
    Run replicate.run() synchronously and record the outcome in
    ``predictions[prediction_id]``.

    Intended to be submitted to a worker thread, because replicate.run()
    blocks until the model finishes. Mirrors the standalone_video_creator.py
    approach.

    Outcome written to the store:
    * ``status="succeeded"`` with ``url``/``output`` set to the video URL, or
    * ``status="failed"`` with ``error`` set, when the token is missing, the
      call raises, or the output contains no video URL.

    If the entry has already left the "processing" state when the run ends
    (for example, it was cancelled), it is left untouched.

    Args:
        prediction_id: Key of the entry in ``predictions`` to update.
        model: Replicate model reference passed to ``replicate.run``.
        input_data: Model input dict passed as ``input=``.

    Exceptions raised while running the model are recorded as a "failed"
    entry rather than propagated.
    """
    try:
        # The SDK's default client reads REPLICATE_API_TOKEN from the
        # environment itself, so nothing needs copying. Fail with a clear
        # message if it is missing (assigning None into os.environ would
        # raise an opaque TypeError instead).
        if not os.getenv('REPLICATE_API_TOKEN'):
            raise RuntimeError("REPLICATE_API_TOKEN not configured")

        print(f"🎬 Running replicate.run('{model}')...")
        print(f"πŸ“¦ Input keys: {list(input_data.keys())}")

        # Blocking call; this function is expected to run off the event loop.
        output = replicate.run(model, input=input_data)
        video_url = _extract_video_url(output)

        print(f"βœ… Replicate completed: {video_url[:80] if video_url else 'no url'}...")

        if video_url:
            result = {
                "status": "succeeded",
                "url": video_url,
                "output": video_url,
                "error": None
            }
        else:
            # A "succeeded" entry without a URL leaves the client nothing to
            # show, so report it as a failure instead.
            result = {
                "status": "failed",
                "url": None,
                "output": None,
                "error": "Replicate returned no video URL"
            }

    except Exception as e:
        error_msg = str(e)
        print(f"❌ Replicate error: {error_msg}")

        result = {
            "status": "failed",
            "url": None,
            "output": None,
            "error": error_msg
        }

    # Only overwrite entries that are still "processing". A status set while
    # the model was running (e.g. by the cancel endpoint) must not be clobbered.
    current = predictions.get(prediction_id)
    if current is not None and current.get("status") != "processing":
        return
    predictions[prediction_id] = result


@router.post("/replicate/generate", response_model=ReplicateGenerateResponse)
async def generate_video(request: ReplicateGenerateRequest, background_tasks: BackgroundTasks):
    """
    Start a Replicate video generation and return a local prediction id
    immediately.

    Mirrors standalone_video_creator.py:
    - Uses replicate.run() (on the module's thread pool, since it blocks)
    - Sends the prompt as-is (the frontend should send a text prompt)
    - Supports an image URL for frame continuity

    Poll GET /replicate/status/{id} for the result. ``background_tasks`` is
    accepted but not used; the work is submitted to ``executor`` instead.

    Raises:
        HTTPException: 500 when the replicate package is missing,
            REPLICATE_API_TOKEN is not set, or the job could not be started.
    """
    if not REPLICATE_AVAILABLE:
        raise HTTPException(
            status_code=500,
            detail="Replicate package not installed. Run: pip install replicate"
        )

    try:
        # Verify API key is set (raises HTTPException 500 if missing)
        get_replicate_api_key()

        model_id = request.model or "google/veo-3"

        # Map the camelCase request fields onto the model's input names
        # (matching standalone_video_creator.py).
        input_data: Dict[str, Any] = {
            "prompt": request.prompt,
        }
        if request.aspectRatio:
            input_data["aspect_ratio"] = request.aspectRatio
        # `is not None` so that an explicit seed of 0 is still forwarded
        if request.seed is not None:
            input_data["seed"] = request.seed
        if request.imageUrl:
            input_data["image"] = request.imageUrl

        print(f"🎬 Starting Replicate generation with model: {model_id}")
        print(f"πŸ“ Prompt: {request.prompt[:100]}...")
        if request.imageUrl:
            print(f"πŸ–ΌοΈ Using reference image: {request.imageUrl[:50]}...")
        print(f"βš™οΈ Input params: {list(input_data.keys())}")

        prediction_id = f"rep_{uuid.uuid4().hex[:12]}"

        # Register the entry before submitting the job, so this initial
        # "processing" write can never land after the worker's final result.
        predictions[prediction_id] = {
            "status": "processing",
            "url": None,
            "output": None,
            "error": None
        }

        # Inside a coroutine, get_running_loop() is the documented way to
        # reach the current loop (get_event_loop() is the legacy API).
        loop = asyncio.get_running_loop()
        try:
            # Fire-and-forget: run_replicate_sync records its own outcome in
            # `predictions`, so the returned future is not awaited.
            loop.run_in_executor(
                executor,
                run_replicate_sync,
                prediction_id,
                model_id,
                input_data
            )
        except Exception:
            # Nothing will ever update this entry, so drop it rather than
            # leave it stuck in "processing".
            predictions.pop(prediction_id, None)
            raise

        return ReplicateGenerateResponse(
            id=prediction_id,
            status="processing"
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"❌ Replicate generation error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Replicate generation failed: {str(e)}"
        )


@router.get("/replicate/status/{prediction_id}", response_model=ReplicateStatusResponse)
async def get_prediction_status(prediction_id: str):
    """
    Report the stored state of a prediction started via /replicate/generate.

    Raises:
        HTTPException: 404 when the id is not in the in-memory store.
    """
    try:
        entry = predictions[prediction_id]
    except KeyError:
        raise HTTPException(
            status_code=404,
            detail=f"Prediction not found: {prediction_id}"
        ) from None

    current_status = entry["status"]
    return ReplicateStatusResponse(
        status=current_status,
        url=entry.get("url"),
        output=entry.get("output"),
        error=entry.get("error"),
    )


@router.get("/replicate/models")
async def list_available_models():
    """Return the static catalog of video models offered to the frontend.

    Every entry is a text-to-video model that also accepts an input image.
    """
    # (model id, display name, description)
    catalog = [
        ("google/veo-3", "Google Veo 3 (Recommended)",
         "High-quality text/image-to-video generation"),
        ("minimax/video-01", "MiniMax Video-01",
         "High-quality text-to-video generation"),
        ("luma/ray", "Luma Ray",
         "Cinematic video generation"),
    ]
    return {
        "models": [
            {
                "id": model_ref,
                "name": label,
                "description": blurb,
                "type": "text-to-video",
                "supports_image": True,
            }
            for model_ref, label, blurb in catalog
        ]
    }


@router.post("/replicate/cancel/{prediction_id}")
async def cancel_prediction(prediction_id: str):
    """
    Mark a running prediction as cancelled in the local store.

    Only entries still in the "processing" state are changed, to
    status="failed" with error="Cancelled by user". Finished predictions keep
    their result, so a late cancel cannot turn a succeeded video into a
    failure. Unknown ids are ignored.

    This does not stop the underlying replicate.run() call; nothing here
    touches the worker thread.

    Always responds 200. The "message" text is kept as before for existing
    clients; the added "status" field carries the entry's resulting status
    (None for unknown ids).
    """
    pred = predictions.get(prediction_id)
    if pred is not None and pred.get("status") == "processing":
        pred["status"] = "failed"
        pred["error"] = "Cancelled by user"

    return JSONResponse(
        status_code=200,
        content={
            "message": "Prediction cancelled",
            "id": prediction_id,
            "status": pred.get("status") if pred is not None else None,
        }
    )


@router.get("/replicate/health")
async def check_replicate_health():
    """
    Report whether Replicate can be used by this service.

    Returns a dict with:
        configured: True when REPLICATE_API_TOKEN is set to a non-empty value.
        package_installed: True when the replicate package imported.
        message: "Replicate is ready", or "Missing: " followed by the missing
            pieces separated by ", " (previously the pieces were run together,
            e.g. "Missing:  replicate package").
    """
    api_key = os.getenv('REPLICATE_API_TOKEN')

    missing = []
    if not api_key:
        missing.append("REPLICATE_API_TOKEN")
    if not REPLICATE_AVAILABLE:
        missing.append("replicate package")

    return {
        "configured": bool(api_key),
        "package_installed": REPLICATE_AVAILABLE,
        "message": ("Missing: " + ", ".join(missing)) if missing else "Replicate is ready",
    }