# language-learn/backend/hf_client.py
# Commit 96ea3f8 (verified, swapmyface):
#   Fix: use correct HF router endpoint and InferenceClient for TTS/STT
"""HuggingFace Inference API wrapper for LLM, TTS, and STT."""
import os
import io
import logging
import requests
logger = logging.getLogger(__name__)
HF_TOKEN = os.environ.get("HF_TOKEN", "")
CHAT_API_URL = "https://router.huggingface.co/v1/chat/completions"
PRIMARY_MODEL = "Qwen/Qwen2.5-72B-Instruct"
FALLBACK_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
STT_MODEL = "openai/whisper-base"
_inference_client = None
def _get_client():
    """Return the process-wide HF InferenceClient, creating it on first use."""
    global _inference_client
    if _inference_client is not None:
        return _inference_client
    # Imported lazily so huggingface_hub is only required when TTS/STT is used.
    from huggingface_hub import InferenceClient
    _inference_client = InferenceClient(token=HF_TOKEN)
    return _inference_client
def chat_completion(messages, max_tokens=1024, temperature=0.7):
    """Send a chat-completion request via HF's OpenAI-compatible router.

    Tries PRIMARY_MODEL first, then FALLBACK_MODEL. If every attempt fails
    (network error, bad JSON, unexpected response shape, or empty content),
    a canned apology is returned so callers always get a usable dict.

    Args:
        messages: list of ``{"role": ..., "content": ...}`` chat messages.
        max_tokens: generation cap forwarded to the API.
        temperature: sampling temperature forwarded to the API.

    Returns:
        dict with "content" (assistant reply) and "model" (the model that
        produced it, or "fallback" when every model failed).
    """
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }
    for model in (PRIMARY_MODEL, FALLBACK_MODEL):
        # Build the payload per attempt rather than mutating a shared dict.
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,
        }
        try:
            resp = requests.post(CHAT_API_URL, headers=headers, json=payload, timeout=60)
            resp.raise_for_status()
            data = resp.json()
            # "choices" may be missing or empty on error payloads; the
            # resulting KeyError/IndexError is caught below and we move on
            # to the fallback model.
            content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
        except (requests.RequestException, ValueError, KeyError, IndexError, AttributeError) as e:
            # ValueError covers malformed JSON from resp.json(); the lookup
            # errors cover unexpected response shapes.
            logger.warning("Model %s failed: %s", model, e)
            continue
        if content:
            return {"content": content, "model": model}
        # Don't fail silently on an empty completion — note it and fall back.
        logger.warning("Model %s returned empty content", model)
    return {"content": "I'm sorry, I'm having trouble connecting right now. Please try again.", "model": "fallback"}
def text_to_speech(text, tts_model="facebook/mms-tts-hin"):
    """Synthesize *text* into audio bytes via the HF InferenceClient.

    Returns the raw audio bytes, or None when the input is blank, the
    request fails, or the response is too small to be real audio.
    """
    if not text or not text.strip():
        return None
    # Cap request size: only the first 500 characters are synthesized.
    snippet = text[:500]
    try:
        result = _get_client().text_to_speech(snippet, model=tts_model)
    except Exception as exc:
        logger.warning(f"TTS failed for model {tts_model}: {exc}")
        return None
    # Payloads of <= 100 bytes are treated as bogus/empty responses.
    if isinstance(result, bytes) and len(result) > 100:
        return result
    return None
def speech_to_text(audio_bytes):
    """Transcribe *audio_bytes* to text via the HF InferenceClient.

    Returns "" for empty input or on any transcription failure.
    """
    if not audio_bytes:
        return ""
    try:
        response = _get_client().automatic_speech_recognition(audio_bytes, model=STT_MODEL)
    except Exception as exc:
        logger.warning(f"STT failed: {exc}")
        return ""
    # The client may hand back a dict, an object exposing .text, or a
    # plain value — normalize all three to a string.
    if isinstance(response, dict):
        return response.get("text", "")
    if hasattr(response, "text"):
        return response.text
    return str(response) if response else ""
def get_model_info():
    """Return a dict naming the models this module is configured to use."""
    info = {
        "llm": PRIMARY_MODEL,
        "llm_fallback": FALLBACK_MODEL,
        "stt": STT_MODEL,
        # TTS model is chosen per-call; only the family is fixed here.
        "tts": "facebook/mms-tts-*",
    }
    return info