drrobot9 committed
Commit a988a6f · verified · 1 Parent(s): 3845214

Update main.py

Files changed (1):
  1. main.py +36 -32
main.py CHANGED
@@ -1,9 +1,8 @@
 import os
 import json
 import torch
-import librosa
+import torchaudio
 import requests
-import soundfile as sf
 from fastapi import FastAPI, UploadFile, File
 from fastapi.responses import FileResponse
 from transformers import (
@@ -24,82 +23,88 @@ VOICE_ID = config["eleven_voice_id"]
 LLM_URL = config["llm_url"]
 
 
-# STT Model
-
+
+def load_audio(audio_path, target_sr=16000):
+    wav, sr = torchaudio.load(audio_path)
+
+    if wav.shape[0] > 1:
+        wav = wav.mean(dim=0, keepdim=True)
+
+
+    if sr != target_sr:
+        wav = torchaudio.functional.resample(wav, sr, target_sr)
+
+    return wav.squeeze().numpy(), target_sr
+
+
+# STT MODEL
 print("Loading STT model...")
 stt_processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")
 stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all").to(DEVICE)
 stt_model.eval()
-print("STT loaded ")
-
+print("STT loaded")
 
 def transcribe(audio_path):
-    wav, sr = librosa.load(audio_path, sr=16000)
-    inputs = stt_processor(wav, sampling_rate=16000, return_tensors="pt", padding=True)
+    wav, sr = load_audio(audio_path)
+    inputs = stt_processor(wav, sampling_rate=sr, return_tensors="pt", padding=True)
     with torch.no_grad():
         logits = stt_model(inputs.input_values.to(DEVICE)).logits
     ids = torch.argmax(logits, dim=-1)
     return stt_processor.batch_decode(ids)[0].strip()
 
 
-# Emotion Model
-
+# EMOTION MODEL #
 print("Loading Emotion model...")
 emotion_extractor = AutoFeatureExtractor.from_pretrained("superb/hubert-base-superb-er")
 emotion_model = AutoModelForAudioClassification.from_pretrained(
     "superb/hubert-base-superb-er"
 ).to(DEVICE)
 emotion_model.eval()
-print("Emotion model loaded ")
-
+print("Emotion model loaded")
 
 def get_emotion(audio_path):
-    wav, sr = librosa.load(audio_path, sr=16000)
-    feats = emotion_extractor(wav, sampling_rate=16000, return_tensors="pt")
+    wav, sr = load_audio(audio_path)
+    feats = emotion_extractor(wav, sampling_rate=sr, return_tensors="pt")
     with torch.no_grad():
         out = emotion_model(feats["input_values"].to(DEVICE))
     pred = torch.argmax(out.logits, dim=-1).item()
     return emotion_model.config.id2label[pred]
 
 
-
-# LLM Call
-
+# LLM CALL
 def ask_llm(text):
     payload = {"query": text}
     r = requests.post(LLM_URL, json=payload, timeout=200)
+
     try:
         return r.json()["answer"]
     except:
         return str(r.json())
 
 
-
-# TTS
-
+# TTS
 def tts_eleven(text, out_file="response.mp3"):
     url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}"
     headers = {
         "xi-api-key": ELEVEN_API_KEY,
-        "Content-Type": "application/json"
+        "Content-Type": "application/json",
     }
     payload = {"text": text, "model_id": "eleven_multilingual_v2"}
 
     resp = requests.post(url, json=payload, headers=headers)
     if resp.status_code != 200:
-        raise Exception(f"ElevenLabs TTS Error: {resp.text}")
+        raise Exception(f"ElevenLabs API Error: {resp.text}")
 
     with open(out_file, "wb") as f:
         f.write(resp.content)
-    return out_file
-
 
+    return out_file
 
-# FastAPI App
 
+# FASTAPI APP
 app = FastAPI(title="Voice AI API")
 
-# Enable CORS for Hugging Face Spaces frontend
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -108,23 +113,22 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
 @app.post("/process-audio/")
 async def process_audio(file: UploadFile = File(...)):
     audio_path = f"temp_{file.filename}"
     with open(audio_path, "wb") as f:
         f.write(await file.read())
 
-
     transcript = transcribe(audio_path)
     emotion = get_emotion(audio_path)
-    llm_out = ask_llm(transcript)
-    tts_file = tts_eleven(llm_out)
+    llm_response = ask_llm(transcript)
+    tts_file = tts_eleven(llm_response)
 
-    # Return TTS file as downloadable mp3
     return FileResponse(tts_file, media_type="audio/mpeg", filename="response.mp3")
 
 
 @app.get("/")
 async def root():
-    return {"message": "Voice AI API is running. Use /process-audio/ endpoint to upload audio."}
+    return {
+        "message": "Voice AI API is running. Use /process-audio/ to upload audio."
+    }
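
As a quick sanity check of the updated endpoint, here is a minimal client sketch. The /process-audio/ route, the "file" upload field, and the MP3 response come from the code above; the base URL and the input filename are assumptions for illustration, not part of this commit.

import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment; replace with the real host

# POST an audio file to /process-audio/; the form field name must be "file"
# to match the UploadFile parameter in process_audio().
with open("question.wav", "rb") as audio:  # hypothetical input file
    resp = requests.post(
        f"{BASE_URL}/process-audio/",
        files={"file": ("question.wav", audio, "audio/wav")},
    )
resp.raise_for_status()

# The endpoint returns the ElevenLabs MP3 via FileResponse; save it to disk.
with open("response.mp3", "wb") as out:
    out.write(resp.content)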