iajitpanday commited on
Commit
b667242
·
verified ·
1 Parent(s): e2b372a

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +50 -0
utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, VitsModel
2
+ import soundfile as sf
3
+ import torch
4
+ import io
5
+ import os
6
+
7
# Speech-to-Text (Whisper)
def transcribe_audio(audio_path):
    """Transcribe an audio file to text with the Whisper "tiny" model.

    Parameters
    ----------
    audio_path : str
        Path to an audio file on disk (any format soundfile/ffmpeg can read).

    Returns
    -------
    str
        The recognized text, or a spoken-apology fallback string if anything fails.
    """
    try:
        # NOTE(review): loading the pipeline on every call is slow; callers that
        # transcribe repeatedly should hoist this to module scope or cache it.
        whisper = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
        # The ASR pipeline resamples input to the 16 kHz rate Whisper expects,
        # so no manual conversion is needed. (The previous
        # sf.read(audio_path, samplerate=8000) never resampled anyway --
        # soundfile only honors `samplerate` for headerless RAW files and
        # raises TypeError for WAV, which sent every non-8kHz input straight
        # to the fallback branch and also risked clobbering the input file.)
        result = whisper(audio_path)
        return result["text"]
    except Exception as e:  # broad by design: always return something speakable
        print(f"STT Error: {e}")
        return "Sorry, I couldn't understand that."
20
+
21
# NLP (Falcon-7B-Instruct)
def generate_response(text):
    """Generate a polite customer-support reply to *text* with Falcon-7B-Instruct.

    Parameters
    ----------
    text : str
        The user's message.

    Returns
    -------
    str
        The model's reply (text after the final "Agent:" marker), or a
        fallback apology string if generation fails.
    """
    try:
        # NOTE(review): reloading a 7B model per call is very expensive;
        # consider caching tokenizer/model at module scope.
        tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")
        model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b-instruct")
        prompt = (
            "You are a polite and helpful customer support agent. Respond professionally.\n"
            f"User: {text}\nAgent:"
        )
        inputs = tokenizer(prompt, return_tensors="pt")
        # max_new_tokens bounds only the reply. The previous max_length=200
        # counted prompt tokens too, so a long user message left little or
        # no room for the answer.
        outputs = model.generate(**inputs, max_new_tokens=200, do_sample=True, top_p=0.9)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Take the text after the *last* "Agent:" marker: the user's message
        # may itself contain "Agent:", which would break a naive split()[1].
        return response.rsplit("Agent:", 1)[-1].strip()
    except Exception as e:  # broad by design: always return something speakable
        print(f"NLP Error: {e}")
        return "I'm having trouble processing your request. Please try again."
37
+
38
# Text-to-Speech (VITS)
def text_to_speech(text, output_path="output.wav"):
    """Synthesize *text* to a WAV file using the MMS English VITS model.

    Parameters
    ----------
    text : str
        Text to speak.
    output_path : str, optional
        Destination WAV path (default "output.wav").

    Returns
    -------
    str or None
        *output_path* on success, None if synthesis fails.
    """
    try:
        tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
        tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
        inputs = tts_tokenizer(text, return_tensors="pt")
        with torch.no_grad():  # inference only -- no gradients needed
            waveform = tts_model(**inputs).waveform
        # Write at the model's native rate (mms-tts-eng generates 16 kHz).
        # The previous hard-coded 8000 mislabeled the sample rate, producing
        # half-speed, pitched-down audio. If Twilio requires 8 kHz, the
        # samples must actually be resampled, not merely relabeled.
        sampling_rate = tts_model.config.sampling_rate
        sf.write(output_path, waveform.squeeze().numpy(), sampling_rate)
        return output_path
    except Exception as e:  # broad by design: caller checks for None
        print(f"TTS Error: {e}")
        return None