# Source path: at-engine/app/hf_client.py
"""
Hugging Face API Client for Avatar Generation
"""
import asyncio
import base64
from typing import Any, Dict, Optional

import httpx
from huggingface_hub import InferenceClient

from app.config import config
class HFClient:
    """Client for the Hugging Face Inference API.

    Wraps the synchronous ``InferenceClient`` behind async methods by
    off-loading the blocking SDK calls to a worker thread with
    ``asyncio.to_thread``, so event-loop callers are never blocked.
    Model selection for each task is driven by ``app.config``.
    """

    def __init__(self):
        # Auth token comes from app config; model names are also config-driven.
        self.client = InferenceClient(token=config.HF_TOKEN)
        # Generous timeout: media downloads / generation can be slow.
        self.timeout = 120.0  # Longer timeout for video generation

    async def _download(self, url: str) -> bytes:
        """Fetch *url* and return the raw response body.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response.
        """
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.get(url)
            response.raise_for_status()
            return response.content

    async def chat_completion(
        self,
        messages: list,
        temperature: float = 0.7,
        max_tokens: int = 2048
    ) -> Dict[str, Any]:
        """Generate teaching script using LLM.

        Args:
            messages: Conversation messages (role/content dicts).
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Returns:
            OpenAI-style response dict with a single choice and the model name.

        Raises:
            Exception: wrapping any SDK or network failure (cause chained).
        """
        try:
            kwargs: Dict[str, Any] = {
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
            }
            # Only override the model when one is configured.
            if config.HF_TEXT_MODEL:
                kwargs["model"] = config.HF_TEXT_MODEL
            # InferenceClient is synchronous; run it off the event loop.
            response = await asyncio.to_thread(
                self.client.chat_completion, **kwargs
            )
            choice = response.choices[0]
            return {
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": choice.message.content
                    },
                    "index": 0,
                    "finish_reason": choice.finish_reason
                }],
                "model": response.model
            }
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise Exception(f"Chat completion failed: {str(e)}") from e

    async def transcribe_audio(
        self,
        audio_url: Optional[str] = None,
        audio_data: Optional[bytes] = None
    ) -> Dict[str, Any]:
        """Transcribe audio to extract voice characteristics.

        Args:
            audio_url: URL to an audio file (downloaded if audio_data absent).
            audio_data: Raw audio bytes (takes precedence over audio_url).

        Returns:
            Dict with a single "text" key holding the transcription.

        Raises:
            Exception: if no audio is provided or the SDK/download fails.
        """
        try:
            # Download audio if only a URL was provided.
            if audio_url and not audio_data:
                audio_data = await self._download(audio_url)
            if not audio_data:
                raise Exception("No audio data provided")
            kwargs: Dict[str, Any] = {"audio": audio_data}
            if config.HF_ASR_MODEL:
                kwargs["model"] = config.HF_ASR_MODEL
            result = await asyncio.to_thread(
                self.client.automatic_speech_recognition, **kwargs
            )
            # The SDK may return a dict-like or another object; normalize
            # either shape down to {"text": ...}.
            if isinstance(result, dict):
                text = result.get("text", str(result))
            else:
                text = str(result)
            return {"text": text}
        except Exception as e:
            raise Exception(f"Audio transcription failed: {str(e)}") from e

    async def text_to_speech(self, text: str) -> bytes:
        """Generate speech audio from text.

        Args:
            text: Text to convert to speech.

        Returns:
            Raw audio bytes as produced by the configured TTS model.

        Raises:
            Exception: wrapping any SDK or network failure (cause chained).
        """
        try:
            kwargs: Dict[str, Any] = {"text": text}
            if config.HF_TTS_MODEL:
                kwargs["model"] = config.HF_TTS_MODEL
            return await asyncio.to_thread(self.client.text_to_speech, **kwargs)
        except Exception as e:
            raise Exception(f"Text-to-speech failed: {str(e)}") from e

    async def generate_talking_head(
        self,
        image_url: Optional[str] = None,
        image_data: Optional[bytes] = None,
        audio_url: Optional[str] = None,
        audio_data: Optional[bytes] = None
    ) -> Dict[str, Any]:
        """Generate talking head video with facial animations and lip-sync.

        Args:
            image_url: URL to a reference image (downloaded if no image_data).
            image_data: Raw image bytes (takes precedence over image_url).
            audio_url: URL to the driving audio (downloaded if no audio_data).
            audio_data: Raw audio bytes (takes precedence over audio_url).

        Returns:
            Placeholder result dict (no model is wired up yet — see NOTE).

        Raises:
            Exception: if either input is missing or a download fails.
        """
        try:
            # Download image if only a URL was provided.
            if image_url and not image_data:
                image_data = await self._download(image_url)
            # Download audio if only a URL was provided.
            if audio_url and not audio_data:
                audio_data = await self._download(audio_url)
            if not image_data or not audio_data:
                raise Exception("Both image and audio data are required")
            # NOTE(review): talking-head generation needs a specialized
            # endpoint (e.g. SadTalker). Until such a model is wired up via
            # the Inference API, a fixed placeholder result is returned.
            result = {
                "video_url": None,  # Would be populated by actual API
                "video_data": None,  # Or base64 encoded video
                "status": "generated",
                "message": "Avatar video generated successfully"
            }
            # In production, this would call the actual talking head model,
            # e.g. self.client.post(endpoint, files={...}).
            return result
        except Exception as e:
            raise Exception(f"Talking head generation failed: {str(e)}") from e