jebin2's picture
refactor: remove worker pool, use direct fal.ai API calls
693e4e3
"""
Fal.ai Service for video generation.
Python implementation using fal-client SDK.
Uses server-side API key from environment.
"""
import asyncio
import logging
import os
from typing import Optional, Literal
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Fal.ai model identifiers keyed by the task they serve; edit here to swap models.
MODELS = dict(video_generation="fal-ai/veo3.1/fast/image-to-video")

# Literal aliases for the option values accepted by the video-generation call.
AspectRatio = Literal["16:9", "9:16", "auto"]
Resolution = Literal["720p", "1080p"]

# Local-testing switch: FAL_MOCK_MODE=true (case-insensitive) skips real API calls.
MOCK_MODE = os.environ.get("FAL_MOCK_MODE", "false").lower() == "true"

# Placeholder video URL handed back while mock mode is active.
MOCK_VIDEO_URL = "https://v3b.fal.media/files/mock/mock-video.mp4"
def get_fal_api_key() -> str:
    """Return the Fal.ai API key stored in the FAL_KEY environment variable.

    Raises:
        ValueError: If FAL_KEY is unset or empty.
    """
    key = os.environ.get("FAL_KEY")
    if key:
        return key
    raise ValueError("Server Authentication Error: FAL_KEY not configured")
class FalService:
    """
    Fal.ai Service for video generation.

    Uses the server-side API key from the environment (FAL_KEY) unless an
    explicit key is passed to the constructor. Generation is asynchronous:
    ``start_video_generation`` submits a job and returns its request id,
    ``check_video_status`` polls that id, and ``download_video`` saves the
    finished file locally.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Initialize the service with an API key from the argument or env.

        Args:
            api_key: Optional explicit key. When given it is also written to
                the process-wide ``FAL_KEY`` environment variable.

        Raises:
            ValueError: If no key is passed and FAL_KEY is not configured.
        """
        self.api_key = api_key or get_fal_api_key()
        # fal_client reads from FAL_KEY env var by default, so a custom key
        # has to be exported for the SDK to pick it up.
        if api_key:
            os.environ["FAL_KEY"] = api_key

    @staticmethod
    def _load_fal_client():
        """Import and return the ``fal_client`` module.

        Raises:
            ValueError: If the fal-client package is not installed.
        """
        try:
            import fal_client
        except ImportError as exc:
            raise ValueError(
                "fal-client package not installed. Run: pip install fal-client"
            ) from exc
        return fal_client

    @staticmethod
    def _status_name(status) -> str:
        """Return an upper-case state name for a status object.

        An explicit string ``status`` attribute wins; otherwise the class name
        is used (e.g. a ``Completed`` instance yields ``"COMPLETED"``).
        NOTE(review): assumes fal-client reports state via the status object's
        class (Queued / InProgress / Completed) - confirm against the
        installed SDK version.
        """
        explicit = getattr(status, "status", None)
        if isinstance(explicit, str):
            return explicit.upper()
        return type(status).__name__.upper()

    @staticmethod
    def _extract_video_url(result) -> Optional[str]:
        """Pull ``video.url`` out of a result given as a dict or an object."""
        if isinstance(result, dict):
            video = result.get("video")
        else:
            video = getattr(result, "video", None)
        if isinstance(video, dict):
            return video.get("url")
        return getattr(video, "url", None)

    def _handle_api_error(self, error: Exception, context: str):
        """Translate well-known API failures into descriptive ValueErrors.

        Args:
            error: The exception raised by the API call.
            context: Label (here the model name) included in the message.

        Raises:
            ValueError: For authentication (401), payment (402) and rate-limit
                (429) failures, chained to the original error.
            Exception: ``error`` itself is re-raised for anything else.
        """
        msg = str(error)
        if "401" in msg or "Unauthorized" in msg:
            raise ValueError(
                f"Authentication failed ({context}). Check your FAL_KEY is valid."
            ) from error
        if "402" in msg or "Payment Required" in msg:
            raise ValueError(
                f"Insufficient credits ({context}). Add credits at fal.ai."
            ) from error
        # Compare against a lower-case needle: the haystack is lower-cased, so
        # a capitalised "Rate limit" could never match.
        if "429" in msg or "rate limit" in msg.lower():
            raise ValueError(
                f"Rate limit exceeded ({context}). Wait and retry."
            ) from error
        raise error

    async def start_video_generation(
        self,
        base64_image: str,
        mime_type: str,
        prompt: str,
        aspect_ratio: AspectRatio = "16:9",
        resolution: Resolution = "720p",
        number_of_videos: int = 1
    ) -> dict:
        """
        Submit an image-to-video job to the configured Fal.ai model.

        The job is started with ``fal_client.submit`` and this method returns
        as soon as the submit call hands back a handle; use
        ``check_video_status`` with the returned id to poll for the result.

        Args:
            base64_image: Base64-encoded source image.
            mime_type: MIME type of the image, used to build the data URL.
            prompt: Text prompt for the generation.
            aspect_ratio: Requested aspect ratio.
            resolution: Requested resolution.
            number_of_videos: Accepted for interface compatibility; it is not
                forwarded to the API.

        Returns:
            dict with:
            - fal_request_id: Request ID to pass to ``check_video_status``
            - done: False (job only submitted); True in mock mode
            - status: "processing"; "completed" in mock mode
            - video_url: only present in mock mode

        Raises:
            ValueError: If fal-client is missing, or for auth / credit /
                rate-limit failures. Other API errors propagate unchanged.
        """
        # Mock mode for testing without API credits
        if MOCK_MODE:
            import uuid
            mock_request_id = f"mock_fal_{uuid.uuid4().hex[:16]}"
            logger.info(f"[MOCK MODE] Video generation: {mock_request_id}")
            await asyncio.sleep(2)  # Simulate API delay
            return {
                "fal_request_id": mock_request_id,
                "done": True,
                "status": "completed",
                "video_url": MOCK_VIDEO_URL
            }

        fal_client = self._load_fal_client()
        model = MODELS["video_generation"]
        try:
            # submit() starts the job and returns a handle without waiting;
            # it is run in a worker thread so the event loop is not blocked.
            handle = await asyncio.to_thread(
                fal_client.submit,
                model,
                arguments={
                    "prompt": prompt,
                    "image_url": f"data:{mime_type};base64,{base64_image}",
                    "aspect_ratio": aspect_ratio,
                    "resolution": resolution,
                    "generate_audio": True,
                },
            )
        except Exception as error:
            self._handle_api_error(error, model)
            raise  # defensive: _handle_api_error always raises

        # Prefer the handle's request_id; fall back to its string form.
        request_id = handle.request_id if hasattr(handle, 'request_id') else str(handle)
        return {
            "fal_request_id": request_id,
            "done": False,
            "status": "processing",
        }

    async def check_video_status(self, fal_request_id: str) -> dict:
        """
        Check the status of a video generation request.

        Returns immediately with the current status (does not wait).

        Args:
            fal_request_id: ID returned by ``start_video_generation``.

        Returns:
            dict with ``fal_request_id``, ``done`` and ``status``
            ("processing", "completed" or "failed"); plus ``video_url`` when
            completed and ``error`` when failed or when the check itself
            raised (in which case the job is reported as still processing).

        Raises:
            ValueError: If the fal-client package is not installed.
        """
        # Mock mode for testing
        if MOCK_MODE:
            import random
            # Simulate completion after a few checks
            if random.random() > 0.7:
                return {
                    "fal_request_id": fal_request_id,
                    "done": True,
                    "status": "completed",
                    "video_url": MOCK_VIDEO_URL
                }
            return {
                "fal_request_id": fal_request_id,
                "done": False,
                "status": "processing"
            }

        fal_client = self._load_fal_client()
        model = MODELS["video_generation"]
        try:
            # Get status without waiting
            status = await asyncio.to_thread(
                fal_client.status,
                model,
                fal_request_id,
                with_logs=False
            )
            state = self._status_name(status)

            if state == "COMPLETED":
                result = await asyncio.to_thread(
                    fal_client.result,
                    model,
                    fal_request_id
                )
                return {
                    "fal_request_id": fal_request_id,
                    "done": True,
                    "status": "completed",
                    "video_url": self._extract_video_url(result)
                }

            if state == "FAILED":
                return {
                    "fal_request_id": fal_request_id,
                    "done": True,
                    "status": "failed",
                    "error": getattr(status, 'error', 'Unknown error')
                }

            # Anything else (queued, in progress, unknown) is still processing.
            return {
                "fal_request_id": fal_request_id,
                "done": False,
                "status": "processing"
            }
        except Exception as error:
            # Best-effort polling: a failed check is logged and reported as
            # "processing" so the caller can retry.
            logger.error(f"Error checking status for {fal_request_id}: {error}")
            return {
                "fal_request_id": fal_request_id,
                "done": False,
                "status": "processing",
                "error": str(error)
            }

    async def download_video(self, video_url: str, request_id: str) -> str:
        """
        Download a video from fal.ai to local storage.

        The file is written as ``<request_id>.mp4`` into a ``downloads``
        directory located one level above this module's directory (created
        if missing).

        Args:
            video_url: URL of the generated video.
            request_id: Request ID used as the file's base name.

        Returns:
            The local filename (not the full path).

        Raises:
            ValueError: If the download or the file write fails; chained to
                the underlying exception.
        """
        import httpx

        downloads_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "downloads"
        )
        os.makedirs(downloads_dir, exist_ok=True)
        filename = f"{request_id}.mp4"
        filepath = os.path.join(downloads_dir, filename)
        try:
            async with httpx.AsyncClient(timeout=120.0, follow_redirects=True) as client:
                response = await client.get(video_url)
                response.raise_for_status()
                with open(filepath, 'wb') as f:
                    f.write(response.content)
            logger.info(f"Downloaded video to {filepath}")
            return filename
        except Exception as e:
            logger.error(f"Failed to download video: {e}")
            raise ValueError(f"Failed to download video: {e}") from e