File size: 5,469 Bytes
ba824e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""

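External API clients — thin async adapters for third-party video, TTS and LLM endpoints.

Supports:
- Replicate (image-to-video generation)
- ElevenLabs (TTS)
- OpenAI (LLM fallback / TTS fallback)

Each client reads its credentials from environment variables once, when this
module is imported: REPLICATE_API_KEY (optionally REPLICATE_I2V_VERSION),
ELEVENLABS_API_KEY (optionally ELEVENLABS_VOICE_ID) and OPENAI_API_KEY.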

"""

import os
import asyncio
import aiohttp
from typing import Optional


class ReplicateClient:
    """
    Replicate.com async client.

    Used as the primary image-to-video source (e.g. stable-video-diffusion).

    ``API_KEY`` and ``MODEL_VERSION`` come from the ``REPLICATE_API_KEY`` and
    ``REPLICATE_I2V_VERSION`` environment variables. They are read once, when the
    class body is evaluated at import time, so variables set later are not seen.
    """
    API_URL = "https://api.replicate.com/v1"
    API_KEY = os.getenv("REPLICATE_API_KEY", "")
    MODEL_VERSION = os.getenv(
        "REPLICATE_I2V_VERSION",
        "stability-ai/stable-video-diffusion:3f0457e4619daac51203dedb472816fd4af51f3149fa7a9e0b5ffcf1b8172438",
    )
    # Polling schedule: at most MAX_POLLS status re-fetches, POLL_INTERVAL_S seconds
    # apart (defaults keep the original 60 x 3 s budget). Override on a subclass or
    # instance to tune.
    POLL_INTERVAL_S = 3
    MAX_POLLS = 60

    def __init__(self, timeout: int = 180):
        # Default timeout for each request made through the session.
        self.timeout = aiohttp.ClientTimeout(total=timeout)

    @staticmethod
    def _prediction_succeeded(pred: dict) -> bool:
        """
        Report whether a prediction has finished successfully.

        Returns True for status ``succeeded`` and False for any other non-terminal
        status (e.g. still starting/processing).

        Raises:
            RuntimeError: if the status is ``failed`` or ``canceled``.
        """
        status = pred.get("status")
        if status in ("failed", "canceled"):
            raise RuntimeError(f"Replicate prediction failed: {pred.get('error')}")
        return status == "succeeded"

    @staticmethod
    def _extract_video_url(pred: dict) -> str:
        """
        Return the video URL from a succeeded prediction's ``output`` field.

        Handles both a single URI string and a list of URIs. Models may return
        either shape, and the model is configurable via REPLICATE_I2V_VERSION. For
        a list, the first entry is used.

        Raises:
            RuntimeError: if the output is missing, empty, or not a string URL.
        """
        output = pred.get("output")
        if isinstance(output, (list, tuple)):
            output = output[0] if output else None
        if not isinstance(output, str) or not output:
            raise RuntimeError(f"Replicate prediction returned no video URL: {output!r}")
        return output

    async def image_to_video(self, image_url: str) -> bytes:
        """
        Submit an image-to-video prediction and poll until it completes.

        Args:
            image_url: URL of the source image, sent as the model's ``image`` input.

        Returns:
            Raw video bytes downloaded from the prediction's output URL.

        Raises:
            aiohttp.ClientResponseError: if any HTTP request returns an error status.
            RuntimeError: if the prediction fails or is canceled, yields no output
                URL, or has not succeeded after MAX_POLLS status re-fetches.
        """
        headers = {
            "Authorization": f"Token {self.API_KEY}",
            "Content-Type": "application/json",
        }
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            # Create the prediction.
            async with session.post(
                f"{self.API_URL}/predictions",
                json={"version": self.MODEL_VERSION, "input": {"image": image_url}},
                headers=headers,
            ) as resp:
                resp.raise_for_status()
                pred = await resp.json()

            pred_url = f"{self.API_URL}/predictions/{pred['id']}"

            # The create response already carries a status, so check it before the
            # first sleep. Then re-fetch until success, failure, or budget exhaustion.
            polls = 0
            while not self._prediction_succeeded(pred):
                if polls >= self.MAX_POLLS:
                    raise RuntimeError("Replicate prediction timed out.")
                polls += 1
                await asyncio.sleep(self.POLL_INTERVAL_S)
                async with session.get(pred_url, headers=headers) as resp:
                    resp.raise_for_status()
                    pred = await resp.json()

            video_url = self._extract_video_url(pred)

            # Download the video. The output URL is fetched without the API headers.
            async with session.get(video_url) as resp:
                resp.raise_for_status()
                return await resp.read()


class ElevenLabsClient:
    """
    ElevenLabs TTS async client.

    ``API_KEY`` and the default ``VOICE_ID`` come from the ``ELEVENLABS_API_KEY``
    and ``ELEVENLABS_VOICE_ID`` environment variables. They are read once, when the
    class body is evaluated at import time.
    """
    API_URL = "https://api.elevenlabs.io/v1"
    API_KEY = os.getenv("ELEVENLABS_API_KEY", "")
    VOICE_ID = os.getenv("ELEVENLABS_VOICE_ID", "21m00Tcm4TlvDq8ikWAM")  # Rachel

    def __init__(self, timeout: int = 60):
        # Default timeout for each request made through the session.
        self.timeout = aiohttp.ClientTimeout(total=timeout)

    async def tts(
        self,
        text: str,
        model: str = "eleven_turbo_v2",
        voice_id: Optional[str] = None,
    ) -> bytes:
        """
        Synthesize ``text`` and return MP3 audio bytes.

        Args:
            text: Text to speak.
            model: ElevenLabs model id, sent as ``model_id``.
            voice_id: Voice to synthesize with. If None, the class-level
                ``VOICE_ID`` is used, which preserves the previous behavior.

        Returns:
            The raw response body (audio bytes).

        Raises:
            aiohttp.ClientResponseError: if the API returns an error status.
        """
        headers = {
            "xi-api-key": self.API_KEY,
            "Content-Type": "application/json",
        }
        payload = {
            "text": text,
            "model_id": model,
            "voice_settings": {"stability": 0.5, "similarity_boost": 0.75},
        }
        voice = voice_id if voice_id is not None else self.VOICE_ID
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async with session.post(
                f"{self.API_URL}/text-to-speech/{voice}",
                json=payload,
                headers=headers,
            ) as resp:
                resp.raise_for_status()
                return await resp.read()


class OpenAIClient:
    """
    OpenAI fallback client for LLM chat and TTS.

    ``API_KEY`` comes from the ``OPENAI_API_KEY`` environment variable. It is read
    once, when the class body is evaluated at import time.
    """
    API_URL = "https://api.openai.com/v1"
    API_KEY = os.getenv("OPENAI_API_KEY", "")

    def __init__(self, timeout: int = 60):
        # Default timeout for each request made through the session.
        self.timeout = aiohttp.ClientTimeout(total=timeout)

    def _headers(self) -> dict:
        """Build the bearer-auth JSON headers shared by every request."""
        return {
            "Authorization": f"Bearer {self.API_KEY}",
            "Content-Type": "application/json",
        }

    async def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        model: str = "gpt-4o-mini",
        *,
        temperature: float = 0.7,
    ) -> str:
        """
        Send a single system + user exchange and return the assistant's reply.

        Args:
            system_prompt: Content of the ``system`` message.
            user_prompt: Content of the ``user`` message.
            model: Chat model name.
            temperature: Sampling temperature (keyword-only; defaults to the
                previously hard-coded 0.7).

        Returns:
            ``choices[0].message.content`` from the response.

        Raises:
            aiohttp.ClientResponseError: if the API returns an error status.
        """
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": temperature,
        }
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async with session.post(
                f"{self.API_URL}/chat/completions",
                json=payload,
                headers=self._headers(),
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()
                return data["choices"][0]["message"]["content"]

    async def tts(self, text: str, voice: str = "alloy", model: str = "tts-1") -> bytes:
        """
        Synthesize ``text`` via OpenAI TTS and return MP3 bytes.

        Args:
            text: Text to speak, sent as ``input``.
            voice: OpenAI voice name.
            model: TTS model name (defaults to the previously hard-coded ``tts-1``).

        Returns:
            The raw response body (audio bytes).

        Raises:
            aiohttp.ClientResponseError: if the API returns an error status.
        """
        payload = {"model": model, "input": text, "voice": voice}
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async with session.post(
                f"{self.API_URL}/audio/speech",
                json=payload,
                headers=self._headers(),
            ) as resp:
                resp.raise_for_status()
                return await resp.read()