# voice_analysis / main.py
# Provenance (Hugging Face Hub): uploaded by drrobot9, commit a988a6f
# ("Update main.py"), verified.
import os
import json
import torch
import torchaudio
import requests
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from transformers import (
Wav2Vec2Processor, Wav2Vec2ForCTC,
AutoFeatureExtractor, AutoModelForAudioClassification
)
from starlette.middleware.cors import CORSMiddleware
# Run all models on the GPU when one is available, otherwise on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
# Load config
# config.json must exist next to this file and contain all three keys below;
# a missing file or key fails fast here, at import time.
with open("config.json") as f:
    config = json.load(f)
ELEVEN_API_KEY = config["eleven_api_key"]  # ElevenLabs API key
VOICE_ID = config["eleven_voice_id"]       # ElevenLabs voice used for synthesis
LLM_URL = config["llm_url"]                # POST endpoint of the backing LLM service
def load_audio(audio_path, target_sr=16000):
    """Read *audio_path*, downmix to mono, and resample to *target_sr*.

    Returns a tuple ``(samples, target_sr)`` where ``samples`` is a 1-D
    numpy array of the (possibly resampled) waveform.
    """
    waveform, orig_sr = torchaudio.load(audio_path)
    # Multi-channel input: average the channels down to a single mono track.
    mono = waveform if waveform.shape[0] == 1 else waveform.mean(dim=0, keepdim=True)
    if orig_sr != target_sr:
        mono = torchaudio.functional.resample(mono, orig_sr, target_sr)
    return mono.squeeze().numpy(), target_sr
# STT MODEL
# facebook/mms-1b-all: multilingual Wav2Vec2 CTC model, loaded once at import
# time, moved to DEVICE, and kept in eval mode for inference only.
print("Loading STT model...")
stt_processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")
stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all").to(DEVICE)
stt_model.eval()
print("STT loaded")
def transcribe(audio_path):
    """Transcribe the audio file at *audio_path* with the MMS CTC model."""
    samples, rate = load_audio(audio_path)
    batch = stt_processor(samples, sampling_rate=rate, return_tensors="pt", padding=True)
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = stt_model(batch.input_values.to(DEVICE)).logits
    token_ids = logits.argmax(dim=-1)
    # Greedy CTC decode; take the single item out of the batch.
    return stt_processor.batch_decode(token_ids)[0].strip()
# EMOTION MODEL #
# superb/hubert-base-superb-er: HuBERT fine-tuned for emotion recognition
# (SUPERB benchmark). Loaded once at import time, eval mode, on DEVICE.
print("Loading Emotion model...")
emotion_extractor = AutoFeatureExtractor.from_pretrained("superb/hubert-base-superb-er")
emotion_model = AutoModelForAudioClassification.from_pretrained(
    "superb/hubert-base-superb-er"
).to(DEVICE)
emotion_model.eval()
print("Emotion model loaded")
def get_emotion(audio_path):
    """Return the predicted emotion label for the audio file at *audio_path*."""
    samples, rate = load_audio(audio_path)
    features = emotion_extractor(samples, sampling_rate=rate, return_tensors="pt")
    with torch.no_grad():
        logits = emotion_model(features["input_values"].to(DEVICE)).logits
    # Highest-scoring class index, mapped back to its human-readable label.
    best = logits.argmax(dim=-1).item()
    return emotion_model.config.id2label[best]
# LLM CALL
def ask_llm(text):
    """POST *text* to the configured LLM service and return its answer.

    Expects a JSON response containing an "answer" key. Falls back to the
    raw body when the response is not JSON, or to the stringified payload
    when the "answer" key is absent.
    """
    response = requests.post(LLM_URL, json={"query": text}, timeout=200)
    try:
        body = response.json()
    except ValueError:
        # Non-JSON body. The original bare `except` fell through to
        # `str(r.json())`, which re-parsed and raised again — return the
        # raw text instead of crashing.
        return response.text
    try:
        return body["answer"]
    except (KeyError, TypeError):
        return str(body)
# TTS
def tts_eleven(text, out_file="response.mp3"):
    """Synthesize *text* with the ElevenLabs API and write MP3 to *out_file*.

    Returns the output file path. Raises RuntimeError (an Exception
    subclass, so existing handlers still match) on a non-200 response.
    """
    url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}"
    headers = {
        "xi-api-key": ELEVEN_API_KEY,
        "Content-Type": "application/json",
    }
    payload = {"text": text, "model_id": "eleven_multilingual_v2"}
    # timeout added: without one, a stalled API call hangs the request forever.
    resp = requests.post(url, json=payload, headers=headers, timeout=120)
    if resp.status_code != 200:
        raise RuntimeError(f"ElevenLabs API Error: {resp.text}")
    with open(out_file, "wb") as f:
        f.write(resp.content)
    return out_file
# FASTAPI APP
app = FastAPI(title="Voice AI API")
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS/Fetch spec — browsers reject a wildcard origin on
# credentialed requests. Confirm whether credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/process-audio/")
async def process_audio(file: UploadFile = File(...)):
    """Full pipeline: uploaded audio -> STT -> emotion -> LLM -> TTS (MP3).

    The upload is saved to a local temp file, transcribed and classified,
    the transcript is sent to the LLM, and the LLM answer is synthesized
    with ElevenLabs and returned as an MP3 file response.
    """
    # basename() strips any directory components a malicious client could
    # put in the filename (path traversal); the original wrote it verbatim.
    audio_path = f"temp_{os.path.basename(file.filename or 'upload')}"
    with open(audio_path, "wb") as f:
        f.write(await file.read())
    try:
        transcript = transcribe(audio_path)
        # NOTE(review): emotion is computed but never used in the response —
        # confirm whether it should feed into the LLM prompt or the reply.
        emotion = get_emotion(audio_path)
        llm_response = ask_llm(transcript)
        tts_file = tts_eleven(llm_response)
    finally:
        # Delete the temp upload even when a pipeline stage fails; the
        # original leaked one file per request.
        if os.path.exists(audio_path):
            os.remove(audio_path)
    return FileResponse(tts_file, media_type="audio/mpeg", filename="response.mp3")
@app.get("/")
async def root():
    """Health-check endpoint that points callers at /process-audio/."""
    message = "Voice AI API is running. Use /process-audio/ to upload audio."
    return {"message": message}