swapmyface committed on
Commit
96ea3f8
·
verified ·
1 Parent(s): 07bc9aa

Fix: use correct HF router endpoint and InferenceClient for TTS/STT

Browse files
Files changed (1) hide show
  1. backend/hf_client.py +29 -47
backend/hf_client.py CHANGED
@@ -1,34 +1,37 @@
1
  """HuggingFace Inference API wrapper for LLM, TTS, and STT."""
2
 
3
  import os
 
4
  import logging
5
  import requests
6
- import time
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
11
- API_BASE = "https://router.huggingface.co/hf-inference/models"
12
 
 
13
  PRIMARY_MODEL = "Qwen/Qwen2.5-72B-Instruct"
14
  FALLBACK_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
15
  STT_MODEL = "openai/whisper-base"
16
 
17
- HEADERS = {
18
- "Authorization": f"Bearer {HF_TOKEN}",
19
- "Content-Type": "application/json"
20
- }
21
 
22
 
23
- def _get_headers(content_type="application/json"):
24
- return {
25
- "Authorization": f"Bearer {HF_TOKEN}",
26
- "Content-Type": content_type
27
- }
 
 
28
 
29
 
30
  def chat_completion(messages, max_tokens=1024, temperature=0.7):
31
- """Send chat completion request to HF Inference API."""
 
 
 
 
32
  payload = {
33
  "model": PRIMARY_MODEL,
34
  "messages": messages,
@@ -40,12 +43,7 @@ def chat_completion(messages, max_tokens=1024, temperature=0.7):
40
  for model in [PRIMARY_MODEL, FALLBACK_MODEL]:
41
  try:
42
  payload["model"] = model
43
- resp = requests.post(
44
- "https://router.huggingface.co/hf-inference/v1/chat/completions",
45
- headers=_get_headers(),
46
- json=payload,
47
- timeout=60
48
- )
49
  resp.raise_for_status()
50
  data = resp.json()
51
  content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
@@ -59,24 +57,17 @@ def chat_completion(messages, max_tokens=1024, temperature=0.7):
59
 
60
 
61
  def text_to_speech(text, tts_model="facebook/mms-tts-hin"):
62
- """Convert text to speech audio bytes."""
63
  if not text or not text.strip():
64
  return None
65
 
66
- # Truncate very long text for TTS
67
  tts_text = text[:500]
68
 
69
  try:
70
- resp = requests.post(
71
- f"{API_BASE}/{tts_model}",
72
- headers=_get_headers(),
73
- json={"inputs": tts_text},
74
- timeout=30
75
- )
76
- resp.raise_for_status()
77
- content_type = resp.headers.get("content-type", "")
78
- if "audio" in content_type or len(resp.content) > 1000:
79
- return resp.content
80
  return None
81
  except Exception as e:
82
  logger.warning(f"TTS failed for model {tts_model}: {e}")
@@ -84,27 +75,18 @@ def text_to_speech(text, tts_model="facebook/mms-tts-hin"):
84
 
85
 
86
  def speech_to_text(audio_bytes):
87
- """Transcribe audio to text using Whisper."""
88
  if not audio_bytes:
89
  return ""
90
 
91
  try:
92
- resp = requests.post(
93
- f"{API_BASE}/{STT_MODEL}",
94
- headers={
95
- "Authorization": f"Bearer {HF_TOKEN}",
96
- "Content-Type": "audio/wav"
97
- },
98
- data=audio_bytes,
99
- timeout=30
100
- )
101
- resp.raise_for_status()
102
- data = resp.json()
103
- if isinstance(data, dict):
104
- return data.get("text", "")
105
- if isinstance(data, list) and data:
106
- return data[0].get("text", "")
107
- return ""
108
  except Exception as e:
109
  logger.warning(f"STT failed: {e}")
110
  return ""
 
1
  """HuggingFace Inference API wrapper for LLM, TTS, and STT."""
2
 
3
  import os
4
+ import io
5
  import logging
6
  import requests
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
 
11
 
12
+ CHAT_API_URL = "https://router.huggingface.co/v1/chat/completions"
13
  PRIMARY_MODEL = "Qwen/Qwen2.5-72B-Instruct"
14
  FALLBACK_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
15
  STT_MODEL = "openai/whisper-base"
16
 
17
+ _inference_client = None
 
 
 
18
 
19
 
20
def _get_client():
    """Return the shared HF InferenceClient, constructing it on first call.

    The client is cached in the module-level ``_inference_client`` so every
    TTS/STT call reuses one authenticated client instead of rebuilding it.
    """
    global _inference_client
    if _inference_client is not None:
        return _inference_client
    # Deferred import: the module stays importable even when only the
    # raw-HTTP chat path is exercised and huggingface_hub is not needed yet.
    from huggingface_hub import InferenceClient
    _inference_client = InferenceClient(token=HF_TOKEN)
    return _inference_client
27
 
28
 
29
  def chat_completion(messages, max_tokens=1024, temperature=0.7):
30
+ """Send chat completion request via the OpenAI-compatible endpoint."""
31
+ headers = {
32
+ "Authorization": f"Bearer {HF_TOKEN}",
33
+ "Content-Type": "application/json"
34
+ }
35
  payload = {
36
  "model": PRIMARY_MODEL,
37
  "messages": messages,
 
43
  for model in [PRIMARY_MODEL, FALLBACK_MODEL]:
44
  try:
45
  payload["model"] = model
46
+ resp = requests.post(CHAT_API_URL, headers=headers, json=payload, timeout=60)
 
 
 
 
 
47
  resp.raise_for_status()
48
  data = resp.json()
49
  content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
 
57
 
58
 
59
  def text_to_speech(text, tts_model="facebook/mms-tts-hin"):
60
+ """Convert text to speech audio bytes using HF InferenceClient."""
61
  if not text or not text.strip():
62
  return None
63
 
 
64
  tts_text = text[:500]
65
 
66
  try:
67
+ client = _get_client()
68
+ audio_bytes = client.text_to_speech(tts_text, model=tts_model)
69
+ if isinstance(audio_bytes, bytes) and len(audio_bytes) > 100:
70
+ return audio_bytes
 
 
 
 
 
 
71
  return None
72
  except Exception as e:
73
  logger.warning(f"TTS failed for model {tts_model}: {e}")
 
75
 
76
 
77
def speech_to_text(audio_bytes):
    """Transcribe *audio_bytes* with Whisper via the HF InferenceClient.

    Parameters:
        audio_bytes: raw audio payload; falsy input short-circuits to "".

    Returns the transcript string, or "" on empty input or any failure
    (errors are logged and swallowed so STT never crashes the caller).
    """
    if not audio_bytes:
        return ""

    try:
        transcript = _get_client().automatic_speech_recognition(
            audio_bytes, model=STT_MODEL
        )
        # huggingface_hub may return a plain dict or a typed output object
        # depending on its version — normalize both shapes to a string.
        if isinstance(transcript, dict):
            return transcript.get("text", "")
        if hasattr(transcript, "text"):
            return transcript.text
        return str(transcript) if transcript else ""
    except Exception as e:
        logger.warning(f"STT failed: {e}")
        return ""