added HeyGen
Browse files- App/TTS/Schemas.py +25 -5
- App/TTS/TTSRoutes.py +13 -7
- App/TTS/utils/HeyGen.py +82 -0
App/TTS/Schemas.py
CHANGED
|
@@ -1,23 +1,43 @@
|
|
| 1 |
-
from pydantic import BaseModel,Field
|
| 2 |
-
from typing import List,Optional
|
| 3 |
import uuid
|
| 4 |
|
|
|
|
| 5 |
class Speak(BaseModel):
|
| 6 |
paragraphId: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
| 7 |
speaker: str
|
| 8 |
text: str
|
| 9 |
-
voiceId: str = Field(
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def __init__(self, **data):
|
| 12 |
-
data["text"] =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
super().__init__(**data)
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
class TTSGenerateRequest(BaseModel):
|
| 18 |
paragraphs: List[Speak]
|
| 19 |
requestId: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
| 20 |
-
workspaceId: str =Field(default_factory=lambda: str(uuid.uuid4()))
|
| 21 |
|
| 22 |
|
| 23 |
class StatusRequest(BaseModel):
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field, validator
|
| 2 |
+
from typing import List, Optional
|
| 3 |
import uuid
|
| 4 |
|
| 5 |
+
|
| 6 |
class Speak(BaseModel):
|
| 7 |
paragraphId: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
| 8 |
speaker: str
|
| 9 |
text: str
|
| 10 |
+
voiceId: str = Field(
|
| 11 |
+
default="c60166365edf46589657770d", alias="speaker"
|
| 12 |
+
) # Default speaker value
|
| 13 |
|
| 14 |
def __init__(self, **data):
|
| 15 |
+
data["text"] = (
|
| 16 |
+
data.get("text")
|
| 17 |
+
if "<speak>" in data.get("text")
|
| 18 |
+
else f"<speak>{data.get('text')}</speak>"
|
| 19 |
+
)
|
| 20 |
super().__init__(**data)
|
| 21 |
|
| 22 |
|
| 23 |
+
class HeyGenTTSRequest(BaseModel):
|
| 24 |
+
voice_id: str = Field(default="d7bbcdd6964c47bdaae26decade4a933")
|
| 25 |
+
rate: str = Field(default="1")
|
| 26 |
+
pitch: str = Field(default="-3%")
|
| 27 |
+
text: str = "Sample"
|
| 28 |
+
|
| 29 |
+
@validator("text")
|
| 30 |
+
def validate_age(cls, value, values):
|
| 31 |
+
if not "speak" in value:
|
| 32 |
+
return f'<speak> <voice name="{values.get("voice_id")}"><prosody rate="{values.get("rate")}" pitch="{values.get("pitch")}">{value}</prosody></voice></speak>'
|
| 33 |
+
else:
|
| 34 |
+
return value
|
| 35 |
+
|
| 36 |
|
| 37 |
class TTSGenerateRequest(BaseModel):
|
| 38 |
paragraphs: List[Speak]
|
| 39 |
requestId: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
| 40 |
+
workspaceId: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
| 41 |
|
| 42 |
|
| 43 |
class StatusRequest(BaseModel):
|
App/TTS/TTSRoutes.py
CHANGED
|
@@ -1,27 +1,33 @@
|
|
| 1 |
from fastapi import APIRouter
|
| 2 |
|
| 3 |
|
| 4 |
-
from .Schemas import StatusRequest, TTSGenerateRequest
|
| 5 |
from .utils.Podcastle import PodcastleAPI
|
|
|
|
| 6 |
import os
|
| 7 |
|
| 8 |
tts_router = APIRouter(tags=["TTS"])
|
| 9 |
data = {"username": os.environ.get("USERNAME"), "password": os.environ.get("PASSWORD")}
|
| 10 |
tts = PodcastleAPI(**data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
-
#
|
| 14 |
@tts_router.post("/generate_tts")
|
| 15 |
async def generate_voice(req: TTSGenerateRequest):
|
| 16 |
print("here --entered!")
|
| 17 |
return await tts.make_request(req)
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
@tts_router.post("/status")
|
| 21 |
async def search_id(req: StatusRequest):
|
| 22 |
return await tts.check_status(req)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
# @tts_router.post("/search_text")
|
| 26 |
-
# async def search_text(req: SearchRequest):
|
| 27 |
-
# return TextSearch(query=req.query)
|
|
|
|
| 1 |
from fastapi import APIRouter
|
| 2 |
|
| 3 |
|
| 4 |
+
from .Schemas import StatusRequest, TTSGenerateRequest, HeyGenTTSRequest
|
| 5 |
from .utils.Podcastle import PodcastleAPI
|
| 6 |
+
from .utils.HeyGen import HeygenAPI
|
| 7 |
import os
|
| 8 |
|
| 9 |
tts_router = APIRouter(tags=["TTS"])
|
| 10 |
data = {"username": os.environ.get("USERNAME"), "password": os.environ.get("PASSWORD")}
|
| 11 |
tts = PodcastleAPI(**data)
|
| 12 |
+
data = {
|
| 13 |
+
"account": os.environ.get("HEYGEN_USERNAME"),
|
| 14 |
+
"password": os.environ.get("HEYGEN_PASSWORD"),
|
| 15 |
+
}
|
| 16 |
+
heyGentts = HeygenAPI(**data)
|
| 17 |
|
| 18 |
|
|
|
|
| 19 |
@tts_router.post("/generate_tts")
|
| 20 |
async def generate_voice(req: TTSGenerateRequest):
|
| 21 |
print("here --entered!")
|
| 22 |
return await tts.make_request(req)
|
| 23 |
|
| 24 |
|
| 25 |
+
@tts_router.post("/heygen_tts")
|
| 26 |
+
async def generate_heygen_voice(req: HeyGenTTSRequest):
|
| 27 |
+
print("hey gen here")
|
| 28 |
+
return await heyGentts.tts_request(req)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
@tts_router.post("/status")
|
| 32 |
async def search_id(req: StatusRequest):
|
| 33 |
return await tts.check_status(req)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
App/TTS/utils/HeyGen.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import aiohttp
|
| 2 |
+
from App.TTS.Schemas import HeyGenTTSRequest
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class HeygenAPI:
|
| 6 |
+
def __init__(self, account, password):
|
| 7 |
+
self.base_url = "https://api2.heygen.com/v1"
|
| 8 |
+
self.account = account
|
| 9 |
+
self.password = password
|
| 10 |
+
self.session = None
|
| 11 |
+
self.session_token = None
|
| 12 |
+
|
| 13 |
+
async def create_session(self):
|
| 14 |
+
self.session = aiohttp.ClientSession()
|
| 15 |
+
|
| 16 |
+
async def close_session(self):
|
| 17 |
+
if self.session:
|
| 18 |
+
await self.session.close()
|
| 19 |
+
|
| 20 |
+
async def login(self):
|
| 21 |
+
url = f"{self.base_url}/pacific/login"
|
| 22 |
+
payload = {
|
| 23 |
+
"login_type": "email",
|
| 24 |
+
"account": self.account,
|
| 25 |
+
"password": self.password,
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
if not self.session:
|
| 29 |
+
await self.create_session()
|
| 30 |
+
|
| 31 |
+
async with self.session.post(url, json=payload) as response:
|
| 32 |
+
response_data = await response.json()
|
| 33 |
+
self.session_token = response_data.get("data", {}).get("session_token")
|
| 34 |
+
return response_data
|
| 35 |
+
|
| 36 |
+
async def relogin(self):
|
| 37 |
+
# Function to relogin and update the session token
|
| 38 |
+
login_result = await self.login()
|
| 39 |
+
if login_result.get("code") == 100:
|
| 40 |
+
self.session_token = login_result["data"]["session_token"]
|
| 41 |
+
return True
|
| 42 |
+
return False
|
| 43 |
+
|
| 44 |
+
async def tts_request(self, req: HeyGenTTSRequest):
|
| 45 |
+
if not self.session_token or not self.session_token:
|
| 46 |
+
await self.login()
|
| 47 |
+
|
| 48 |
+
url = f"{self.base_url}/online/text_to_speech.generate"
|
| 49 |
+
headers = {
|
| 50 |
+
"content-type": "application/json",
|
| 51 |
+
"x-session-token": self.session_token,
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
tts_payload = {
|
| 55 |
+
"text_type": "ssml",
|
| 56 |
+
"output_format": "wav",
|
| 57 |
+
"text": req.text,
|
| 58 |
+
"voice_id": req.voice_id,
|
| 59 |
+
"settings": {},
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
async with self.session.post(
|
| 63 |
+
url, json=tts_payload, headers=headers
|
| 64 |
+
) as response:
|
| 65 |
+
if response.status == 401:
|
| 66 |
+
# If a 401 error is encountered, relogin and retry the request
|
| 67 |
+
if await self.relogin():
|
| 68 |
+
headers["x-session-token"] = self.session_token
|
| 69 |
+
response = await self.session.post(
|
| 70 |
+
url, json=tts_payload, headers=headers
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
response_data = await response.json()
|
| 74 |
+
return response_data
|
| 75 |
+
|
| 76 |
+
async def __aenter__(self):
|
| 77 |
+
if not self.session:
|
| 78 |
+
await self.create_session()
|
| 79 |
+
return self
|
| 80 |
+
|
| 81 |
+
async def __aexit__(self, exc_type, exc_value, traceback):
|
| 82 |
+
await self.close_session()
|