import os
import json
import torch
import torchaudio
import requests
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from transformers import (
    Wav2Vec2Processor, Wav2Vec2ForCTC,
    AutoFeatureExtractor, AutoModelForAudioClassification
)
from starlette.middleware.cors import CORSMiddleware

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# Load config
with open("config.json") as f:
    config = json.load(f)

ELEVEN_API_KEY = config["eleven_api_key"]
VOICE_ID = config["eleven_voice_id"]
LLM_URL = config["llm_url"]
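
# Expected shape of config.json, inferred from the keys read above.
# The values below are placeholders, not real credentials:
# {
#   "eleven_api_key": "YOUR_ELEVENLABS_API_KEY",
#   "eleven_voice_id": "YOUR_VOICE_ID",
#   "llm_url": "https://your-llm-endpoint/ask"
# }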

def load_audio(audio_path, target_sr=16000):
    wav, sr = torchaudio.load(audio_path)
    # Mix down to mono if the file has more than one channel.
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    # Resample to the rate expected by the models (16 kHz).
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    return wav.squeeze().numpy(), target_sr

# STT MODEL
print("Loading STT model...")
stt_processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")
stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all").to(DEVICE)
stt_model.eval()
print("STT loaded")

def transcribe(audio_path):
    wav, sr = load_audio(audio_path)
    inputs = stt_processor(wav, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = stt_model(inputs.input_values.to(DEVICE)).logits
    # Greedy CTC decoding: take the most likely token at each frame.
    ids = torch.argmax(logits, dim=-1)
    return stt_processor.batch_decode(ids)[0].strip()

# EMOTION MODEL
print("Loading Emotion model...")
emotion_extractor = AutoFeatureExtractor.from_pretrained("superb/hubert-base-superb-er")
emotion_model = AutoModelForAudioClassification.from_pretrained(
    "superb/hubert-base-superb-er"
).to(DEVICE)
emotion_model.eval()
print("Emotion model loaded")

def get_emotion(audio_path):
    wav, sr = load_audio(audio_path)
    feats = emotion_extractor(wav, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        out = emotion_model(feats["input_values"].to(DEVICE))
    # Return the label of the highest-scoring emotion class.
    pred = torch.argmax(out.logits, dim=-1).item()
    return emotion_model.config.id2label[pred]

# LLM CALL
def ask_llm(text):
    payload = {"query": text}
    r = requests.post(LLM_URL, json=payload, timeout=200)
    data = r.json()
    try:
        return data["answer"]
    except (KeyError, TypeError):
        # Fall back to the raw JSON payload if "answer" is missing.
        return str(data)

# TTS
def tts_eleven(text, out_file="response.mp3"):
    url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}"
    headers = {
        "xi-api-key": ELEVEN_API_KEY,
        "Content-Type": "application/json",
    }
    payload = {"text": text, "model_id": "eleven_multilingual_v2"}
    resp = requests.post(url, json=payload, headers=headers, timeout=120)
    if resp.status_code != 200:
        raise Exception(f"ElevenLabs API Error: {resp.text}")
    # The response body is the synthesized audio (MP3 bytes).
    with open(out_file, "wb") as f:
        f.write(resp.content)
    return out_file

# FASTAPI APP
app = FastAPI(title="Voice AI API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/process-audio/")
async def process_audio(file: UploadFile = File(...)):
    # Persist the upload so the audio models can read it from disk.
    audio_path = f"temp_{file.filename}"
    with open(audio_path, "wb") as f:
        f.write(await file.read())
    try:
        transcript = transcribe(audio_path)
        emotion = get_emotion(audio_path)  # currently not returned to the client
        llm_response = ask_llm(transcript)
        tts_file = tts_eleven(llm_response)
    finally:
        # Clean up the temporary upload.
        if os.path.exists(audio_path):
            os.remove(audio_path)
    return FileResponse(tts_file, media_type="audio/mpeg", filename="response.mp3")

@app.get("/")
async def root():
    return {
        "message": "Voice AI API is running. Use /process-audio/ to upload audio."
    }
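
# ---------------------------------------------------------------
# Local entrypoint (sketch). Assumes uvicorn is installed and that
# port 7860 is the port the Space exposes; adjust as needed. If the
# platform already launches the app (e.g. via `uvicorn app:app`),
# this guard simply never runs.
# Example call once the server is up:
#   curl -X POST -F "file=@sample.wav" \
#        http://localhost:7860/process-audio/ --output response.mp3
# ---------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)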