# NOTE: "Spaces: Running" page-header residue from the HF Spaces scrape removed.
"""HuggingFace Inference API wrapper for LLM, TTS, and STT."""

import os
import io
import logging

import requests
logger = logging.getLogger(__name__)

# HF API token read from the environment; empty string means unauthenticated.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# OpenAI-compatible chat-completions endpoint on the HF inference router.
CHAT_API_URL = "https://router.huggingface.co/v1/chat/completions"

# Chat models tried in order by chat_completion().
PRIMARY_MODEL = "Qwen/Qwen2.5-72B-Instruct"
FALLBACK_MODEL = "meta-llama/Llama-3.2-3B-Instruct"

# Model used by speech_to_text().
STT_MODEL = "openai/whisper-base"

# Cached InferenceClient instance; created lazily by _get_client().
_inference_client = None
def _get_client():
    """Return the module-wide HF InferenceClient, creating it on first use."""
    global _inference_client
    if _inference_client is not None:
        return _inference_client
    # Import inside the function so huggingface_hub is only required
    # once a client is actually needed.
    from huggingface_hub import InferenceClient
    _inference_client = InferenceClient(token=HF_TOKEN)
    return _inference_client
def chat_completion(messages, max_tokens=1024, temperature=0.7):
    """Return an assistant reply for *messages* via the HF router endpoint.

    Tries PRIMARY_MODEL first, then FALLBACK_MODEL. If every model fails
    (network error, bad response, or empty completion), a canned apology is
    returned so callers always receive a usable dict.

    Args:
        messages: list of ``{"role": ..., "content": ...}`` chat messages.
        max_tokens: generation cap forwarded to the API.
        temperature: sampling temperature forwarded to the API.

    Returns:
        dict with ``"content"`` (reply text) and ``"model"`` (the model id
        that produced it, or ``"fallback"`` when all models failed).
    """
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }
    for model in (PRIMARY_MODEL, FALLBACK_MODEL):
        # Build the payload per attempt instead of mutating a shared dict.
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,
        }
        try:
            resp = requests.post(CHAT_API_URL, headers=headers, json=payload, timeout=60)
            resp.raise_for_status()
            data = resp.json()
            # Guard against "choices": [] — the original indexed [0] blindly.
            choices = data.get("choices") or [{}]
            content = choices[0].get("message", {}).get("content", "")
            if content:
                return {"content": content, "model": model}
            logger.warning("Model %s returned an empty completion", model)
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as e:
            # RequestException covers connection/HTTP errors; ValueError covers
            # a non-JSON body from resp.json(); the rest guard against an
            # unexpected response shape. Narrower than the original bare
            # `except Exception`, which also hid programming errors.
            logger.warning("Model %s failed: %s", model, e)
    return {"content": "I'm sorry, I'm having trouble connecting right now. Please try again.", "model": "fallback"}
def text_to_speech(text, tts_model="facebook/mms-tts-hin"):
    """Synthesize *text* into audio bytes; return None on any failure.

    Input is truncated to 500 characters before synthesis. A response that
    is not ``bytes``, or is 100 bytes or smaller, is treated as a failure.
    """
    if not text or not text.strip():
        return None
    snippet = text[:500]  # cap request size — NOTE(review): 500-char limit assumed appropriate for these models
    try:
        audio = _get_client().text_to_speech(snippet, model=tts_model)
    except Exception as e:
        logger.warning(f"TTS failed for model {tts_model}: {e}")
        return None
    return audio if isinstance(audio, bytes) and len(audio) > 100 else None
def speech_to_text(audio_bytes):
    """Transcribe *audio_bytes* with the Whisper STT model; "" on failure."""
    if not audio_bytes:
        return ""
    try:
        result = _get_client().automatic_speech_recognition(audio_bytes, model=STT_MODEL)
    except Exception as e:
        logger.warning(f"STT failed: {e}")
        return ""
    # The client may hand back a dict, an object with .text, or a plain value.
    if isinstance(result, dict):
        return result.get("text", "")
    if hasattr(result, "text"):
        return result.text
    return str(result) if result else ""
def get_model_info():
    """Report the model ids this wrapper targets for each capability."""
    info = {
        "llm": PRIMARY_MODEL,
        "llm_fallback": FALLBACK_MODEL,
        "stt": STT_MODEL,
    }
    # TTS model is chosen per-call in text_to_speech(); advertise the family.
    info["tts"] = "facebook/mms-tts-*"
    return info