# video_contents_generator/services/external_api_client.py
# (uploaded by babaTEEpe in commit ba824e8, "Upload 29 files")
"""
External API clients — async adapters for third-party video, TTS, and LLM endpoints.
Supports:
- Replicate (video generation)
- ElevenLabs (TTS)
- OpenAI (LLM fallback / TTS fallback)
Set REPLICATE_API_KEY, ELEVENLABS_API_KEY, or OPENAI_API_KEY to enable the
corresponding service.
"""
import os
import asyncio
import aiohttp
from typing import Optional
class ReplicateClient:
    """
    Replicate.com async client.

    Used as the primary image-to-video source (e.g. stable-video-diffusion).
    The API key comes from ``REPLICATE_API_KEY``; the model version can be
    overridden with ``REPLICATE_I2V_VERSION``.
    """

    API_URL = "https://api.replicate.com/v1"
    # Read once at import time; _api_key() re-reads the environment when this
    # is empty, so a key loaded after import (e.g. from a .env file) still works.
    API_KEY = os.getenv("REPLICATE_API_KEY", "")
    MODEL_VERSION = os.getenv(
        "REPLICATE_I2V_VERSION",
        "stability-ai/stable-video-diffusion:3f0457e4619daac51203dedb472816fd4af51f3149fa7a9e0b5ffcf1b8172438",
    )
    # Statuses after which polling stops.
    _TERMINAL_STATUSES = ("succeeded", "failed", "canceled")

    def __init__(self, timeout: int = 180, poll_interval: float = 3.0, max_polls: int = 60):
        """
        Args:
            timeout: Seconds, passed as ``aiohttp.ClientTimeout(total=timeout)``
                to the session used for every request.
            poll_interval: Seconds to sleep before each status poll.
            max_polls: Maximum number of status polls before giving up.
        """
        self.timeout = aiohttp.ClientTimeout(total=timeout)
        self.poll_interval = poll_interval
        self.max_polls = max_polls

    def _api_key(self) -> str:
        """Return the configured key, falling back to the current environment."""
        return self.API_KEY or os.getenv("REPLICATE_API_KEY", "")

    @staticmethod
    def _output_url(output) -> str:
        """
        Extract a downloadable URL from a prediction's ``output`` field.

        A non-empty string is returned as-is. For a list/tuple, its first entry
        is used. NOTE(review): confirm the first entry is the video for any
        model configured via REPLICATE_I2V_VERSION.

        Raises:
            RuntimeError: if no non-empty URL string can be found.
        """
        candidate = output[0] if isinstance(output, (list, tuple)) and output else output
        if isinstance(candidate, str) and candidate:
            return candidate
        raise RuntimeError(f"Replicate prediction returned no usable output: {output!r}")

    async def image_to_video(self, image_url: str) -> bytes:
        """
        Submit an image-to-video prediction, poll until it finishes, and
        download the result.

        Args:
            image_url: URL sent as the model's ``image`` input.

        Returns:
            Raw video bytes.

        Raises:
            aiohttp.ClientResponseError: on any non-2xx HTTP response.
            RuntimeError: if the prediction fails, is canceled, does not finish
                within ``max_polls`` polls, or yields no usable output URL.
        """
        headers = {
            "Authorization": f"Token {self._api_key()}",
            "Content-Type": "application/json",
        }
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            # Create the prediction.
            async with session.post(
                f"{self.API_URL}/predictions",
                json={"version": self.MODEL_VERSION, "input": {"image": image_url}},
                headers=headers,
            ) as resp:
                resp.raise_for_status()
                pred = await resp.json()
            pred_url = f"{self.API_URL}/predictions/{pred['id']}"

            # Poll until a terminal status. The creation response is checked
            # first, so an already-finished prediction is not delayed.
            polls = 0
            while pred.get("status") not in self._TERMINAL_STATUSES:
                if polls >= self.max_polls:
                    raise RuntimeError("Replicate prediction timed out.")
                polls += 1
                await asyncio.sleep(self.poll_interval)
                async with session.get(pred_url, headers=headers) as resp:
                    resp.raise_for_status()
                    pred = await resp.json()

            if pred["status"] != "succeeded":
                raise RuntimeError(f"Replicate prediction failed: {pred.get('error')}")

            # Download the generated video (no auth header, as before).
            video_url = self._output_url(pred.get("output"))
            async with session.get(video_url) as resp:
                resp.raise_for_status()
                return await resp.read()
class ElevenLabsClient:
    """ElevenLabs TTS async client (API key from ``ELEVENLABS_API_KEY``)."""

    API_URL = "https://api.elevenlabs.io/v1"
    # Read once at import time; _api_key() re-reads the environment when this
    # is empty, so a key loaded after import still works.
    API_KEY = os.getenv("ELEVENLABS_API_KEY", "")
    VOICE_ID = os.getenv("ELEVENLABS_VOICE_ID", "21m00Tcm4TlvDq8ikWAM")  # Rachel

    def __init__(self, timeout: int = 60):
        """
        Args:
            timeout: Seconds, passed as ``aiohttp.ClientTimeout(total=timeout)``
                to the session.
        """
        self.timeout = aiohttp.ClientTimeout(total=timeout)

    def _api_key(self) -> str:
        """Return the configured key, falling back to the current environment."""
        return self.API_KEY or os.getenv("ELEVENLABS_API_KEY", "")

    async def tts(
        self,
        text: str,
        model: str = "eleven_turbo_v2",
        voice_id: Optional[str] = None,
    ) -> bytes:
        """
        Synthesize ``text`` and return MP3 audio bytes.

        Args:
            text: Text to speak.
            model: ElevenLabs ``model_id``.
            voice_id: Voice to use for this call; defaults to ``VOICE_ID``.

        Raises:
            aiohttp.ClientResponseError: on any non-2xx HTTP response.
        """
        voice = voice_id or self.VOICE_ID
        headers = {
            "xi-api-key": self._api_key(),
            "Content-Type": "application/json",
        }
        payload = {
            "text": text,
            "model_id": model,
            "voice_settings": {"stability": 0.5, "similarity_boost": 0.75},
        }
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async with session.post(
                f"{self.API_URL}/text-to-speech/{voice}",
                json=payload,
                headers=headers,
            ) as resp:
                resp.raise_for_status()
                return await resp.read()
class OpenAIClient:
    """OpenAI fallback client for chat completions and TTS (API key from ``OPENAI_API_KEY``)."""

    API_URL = "https://api.openai.com/v1"
    # Read once at import time; _headers() re-reads the environment when this
    # is empty, so a key loaded after import still works.
    API_KEY = os.getenv("OPENAI_API_KEY", "")

    def __init__(self, timeout: int = 60):
        """
        Args:
            timeout: Seconds, passed as ``aiohttp.ClientTimeout(total=timeout)``
                to the session.
        """
        self.timeout = aiohttp.ClientTimeout(total=timeout)

    def _headers(self) -> dict:
        """Build the JSON + bearer-auth headers shared by every request."""
        api_key = self.API_KEY or os.getenv("OPENAI_API_KEY", "")
        return {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _extract_content(data: dict) -> str:
        """
        Return the first choice's message text from a chat-completions response.

        Raises:
            RuntimeError: if there are no choices, or the message content is
                null (e.g. a refusal), rather than returning ``None``.
        """
        choices = data.get("choices") or []
        if not choices:
            raise RuntimeError("OpenAI chat response contained no choices.")
        first = choices[0]
        message = first.get("message") or {}
        content = message.get("content")
        if content is None:
            raise RuntimeError(
                "OpenAI chat response had no text content "
                f"(refusal={message.get('refusal')!r}, finish_reason={first.get('finish_reason')!r})"
            )
        return content

    async def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        model: str = "gpt-4o-mini",
        temperature: float = 0.7,
    ) -> str:
        """
        Send a system + user prompt pair and return the assistant reply text.

        Raises:
            aiohttp.ClientResponseError: on any non-2xx HTTP response.
            RuntimeError: if the response carries no text content.
        """
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": temperature,
        }
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async with session.post(
                f"{self.API_URL}/chat/completions",
                json=payload,
                headers=self._headers(),
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()
        return self._extract_content(data)

    async def tts(self, text: str, voice: str = "alloy", model: str = "tts-1") -> bytes:
        """
        Synthesize ``text`` via the OpenAI speech endpoint and return MP3 bytes.

        Raises:
            aiohttp.ClientResponseError: on any non-2xx HTTP response.
        """
        payload = {"model": model, "input": text, "voice": voice}
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async with session.post(
                f"{self.API_URL}/audio/speech",
                json=payload,
                headers=self._headers(),
            ) as resp:
                resp.raise_for_status()
                return await resp.read()