Letsch22 commited on
Commit
68c9c6f
·
1 Parent(s): 1fde471

Basic audio input integration to chatbot

Browse files
Files changed (1) hide show
  1. app.py +36 -40
app.py CHANGED
@@ -4,7 +4,7 @@ from time import sleep
4
  from typing import Dict, List, Generator
5
 
6
  import gradio as gr
7
- import openai
8
  from dotenv import load_dotenv
9
 
10
  load_dotenv()
@@ -12,19 +12,40 @@ load_dotenv()
12
  class MockInterviewer:
13
 
14
  def __init__(self) -> None:
15
- self._client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'])
16
  self._assistant_id_cache: Dict[str, str] = {}
17
  self.clear_thread()
18
 
19
- def chat(self, usr_message: Dict, history: List[List], job_role: str, company: str) -> Generator:
20
- print('Started chat')
21
- self._validate_fields(job_role, company)
22
- assistant_id = self._init_assistant(job_role, company)
23
- yield self._send_message(usr_message.get('text'), assistant_id)
24
 
25
  def clear_thread(self) -> None:
26
  print('Initializing new thread')
27
  self._thread = self._client.beta.threads.create()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def _send_message(self, message: str, assistant_id: str) -> str:
30
  self._client.beta.threads.messages.create(thread_id=self._thread.id, role='user', content=message)
@@ -47,14 +68,6 @@ class MockInterviewer:
47
  print(f'Assistant response: {response}')
48
  return response
49
 
50
- def _validate_fields(self, job_role: str, company: str) -> None:
51
- if not job_role and not company:
52
- raise gr.Error('Job Role and Company are required fields.')
53
- if not job_role:
54
- raise gr.Error('Job Role is a required field.')
55
- if not company:
56
- raise gr.Error('Company is a required field.')
57
-
58
  def _create_files(self, company: str) -> List[str]:
59
  if company.lower() == 'amazon':
60
  url = 'https://www.aboutamazon.com/about-us/leadership-principles'
@@ -95,20 +108,6 @@ class MockInterviewer:
95
  def _create_cache_key(self, job_role: str, company: str) -> str:
96
  return f'{job_role.lower()}+{company.lower()}'
97
 
98
- def transcript(audio):
99
- try:
100
- print(audio)
101
- audio_file = open(audio, "rb")
102
- transcriptions = openai.audio.transcriptions.create(
103
- model="whisper-1",
104
- file=audio_file,
105
- )
106
- except Exception as error:
107
- print(str(error))
108
- raise gr.Error("An error occurred while generating speech. Please check your API key and come back try again.")
109
-
110
- return transcriptions.text
111
-
112
  # Creating the Gradio interface
113
  with gr.Blocks() as demo:
114
  mock_interviewer = MockInterviewer()
@@ -116,27 +115,24 @@ with gr.Blocks() as demo:
116
  with gr.Row():
117
  job_role = gr.Textbox(label='Job Role', placeholder='Product Manager')
118
  company = gr.Textbox(label='Company', placeholder='Amazon')
119
- audio = gr.Audio(sources=["microphone"], type="filepath")
120
-
121
- submit_btn = gr.Button("Submit")
122
-
123
- response_output = gr.Textbox(label="Interviewer Response")
124
- stt_output = gr.Textbox(label="Speech-To-Text Transcription")
125
-
126
 
127
  chat_interface = gr.ChatInterface(
128
- fn=lambda usr_message, history, job_role, company: mock_interviewer.chat(usr_message, history, job_role, company),
129
  additional_inputs=[job_role, company],
130
  title='I am your AI mock interviewer',
131
  description='Make your selections above to configure me.',
132
  multimodal=True,
133
  retry_btn=None,
134
- undo_btn=None
135
- ).queue()
136
 
137
  chat_interface.load(mock_interviewer.clear_thread)
138
  chat_interface.clear_btn.click(mock_interviewer.clear_thread)
139
- audio.stop_recording(fn=MockInterviewer.transcript, inputs=[audio], outputs=stt_output, api_name=False)
 
 
 
 
 
140
 
141
  if __name__ == '__main__':
142
  demo.launch().queue()
 
4
  from typing import Dict, List, Generator
5
 
6
  import gradio as gr
7
+ from openai import OpenAI
8
  from dotenv import load_dotenv
9
 
10
  load_dotenv()
 
12
  class MockInterviewer:
13
 
14
  def __init__(self) -> None:
15
+ self._client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
16
  self._assistant_id_cache: Dict[str, str] = {}
17
  self.clear_thread()
18
 
19
+ def interface_chat(self, message: Dict, history: List[List], job_role: str, company: str) -> Generator:
20
+ yield self._chat(message.get('text'), job_role, company)
 
 
 
21
 
22
  def clear_thread(self) -> None:
23
  print('Initializing new thread')
24
  self._thread = self._client.beta.threads.create()
25
+
26
+ def transcript(self, audio: str, job_role: str, company: str) -> str:
27
+ with open(audio, 'rb') as audio_file:
28
+ transcriptions = self._client.audio.transcriptions.create(
29
+ model='whisper-1',
30
+ file=audio_file,
31
+ )
32
+ os.remove(audio)
33
+ response = self._chat(transcriptions.text, job_role, company)
34
+ return [(transcriptions.text, response)]
35
+
36
+ def _chat(self, message: str, job_role: str, company: str) -> str:
37
+ print('Started chat')
38
+ self._validate_fields(job_role, company)
39
+ assistant_id = self._init_assistant(job_role, company)
40
+ return self._send_message(message, assistant_id)
41
+
42
+ def _validate_fields(self, job_role: str, company: str) -> None:
43
+ if not job_role and not company:
44
+ raise gr.Error('Job Role and Company are required fields.')
45
+ if not job_role:
46
+ raise gr.Error('Job Role is a required field.')
47
+ if not company:
48
+ raise gr.Error('Company is a required field.')
49
 
50
  def _send_message(self, message: str, assistant_id: str) -> str:
51
  self._client.beta.threads.messages.create(thread_id=self._thread.id, role='user', content=message)
 
68
  print(f'Assistant response: {response}')
69
  return response
70
 
 
 
 
 
 
 
 
 
71
  def _create_files(self, company: str) -> List[str]:
72
  if company.lower() == 'amazon':
73
  url = 'https://www.aboutamazon.com/about-us/leadership-principles'
 
108
  def _create_cache_key(self, job_role: str, company: str) -> str:
109
  return f'{job_role.lower()}+{company.lower()}'
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # Creating the Gradio interface
112
  with gr.Blocks() as demo:
113
  mock_interviewer = MockInterviewer()
 
115
  with gr.Row():
116
  job_role = gr.Textbox(label='Job Role', placeholder='Product Manager')
117
  company = gr.Textbox(label='Company', placeholder='Amazon')
 
 
 
 
 
 
 
118
 
119
  chat_interface = gr.ChatInterface(
120
+ fn=mock_interviewer.interface_chat,
121
  additional_inputs=[job_role, company],
122
  title='I am your AI mock interviewer',
123
  description='Make your selections above to configure me.',
124
  multimodal=True,
125
  retry_btn=None,
126
+ undo_btn=None).queue()
 
127
 
128
  chat_interface.load(mock_interviewer.clear_thread)
129
  chat_interface.clear_btn.click(mock_interviewer.clear_thread)
130
+
131
+ audio = gr.Audio(sources=['microphone'], type='filepath', editable=False)
132
+ audio.stop_recording(fn=mock_interviewer.transcript,
133
+ inputs=[audio, job_role, company],
134
+ outputs=[chat_interface.chatbot],
135
+ api_name=False)
136
 
137
  if __name__ == '__main__':
138
  demo.launch().queue()