salomonsky committed on
Commit
8aaf5ff
verified
1 Parent(s): 646214e

Update app.py

Files changed (1)
  1. app.py +36 -86
app.py CHANGED
@@ -1,40 +1,18 @@
-import streamlit as st
-import base64
-import io
 from huggingface_hub import InferenceClient
-from gtts import gTTS
-from audiorecorder import audiorecorder
-import speech_recognition as sr
-from pydub import AudioSegment
+import gradio as gr
 
-if "history" not in st.session_state:
-    st.session_state.history = []
-
-def recognize_speech(audio_data, show_messages=True):
-    recognizer = sr.Recognizer()
-    audio_recording = sr.AudioFile(audio_data)
-
-    with audio_recording as source:
-        audio = recognizer.record(source)
-
-    try:
-        audio_text = recognizer.recognize_google(audio, language="es-ES")
-        if show_messages:
-            st.subheader("Texto Reconocido:")
-            st.write(audio_text)
-            st.success("Reconocimiento de voz completado.")
-    except sr.UnknownValueError:
-        st.warning("No se pudo reconocer el audio. ¿Intentaste grabar algo?")
-        audio_text = ""
-    except sr.RequestError:
-        st.error("Hablame para comenzar!")
-        audio_text = ""
-
-    return audio_text
+client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+system_prompt = "Te llamarás Caman 2.0 y tus respuestas serán breves"
+system_prompt_sent = False
 
 def format_prompt(message, history):
+    global system_prompt_sent
     prompt = "<s>"
 
+    if not any(f"[INST] {system_prompt} [/INST]" in user_prompt for user_prompt, _ in history):
+        prompt += f"[INST] {system_prompt} [/INST]"
+        system_prompt_sent = True
+
     for user_prompt, bot_response in history:
         prompt += f"[INST] {user_prompt} [/INST]"
         prompt += f" {bot_response}</s> "
@@ -42,10 +20,11 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
-def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
-    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
-
-    temperature = float(temperature) if temperature is not None else 0.9
+def generate(
+    prompt, history, temperature=0.9, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0,
+):
+    global system_prompt_sent
+    temperature = float(temperature)
     if temperature < 1e-2:
         temperature = 1e-2
     top_p = float(top_p)
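A small sketch (not part of the commit) of the clamp this hunk keeps: sampling backends typically reject a temperature of 0 when sampling is enabled, so values below 1e-2 are floored.

# Illustrative sketch only, not part of this commit: the clamp floors the
# sampling temperature at 1e-2, since inference backends typically reject
# temperature <= 0 when sampling is enabled.
def clamp_temperature(temperature=0.9):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    return temperature

assert clamp_temperature(0.0) == 1e-2  # too-cold values are floored
assert clamp_temperature(0.9) == 0.9   # typical values pass through unchanged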
@@ -59,55 +38,26 @@ def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.
         seed=42,
     )
 
-    formatted_prompt = format_prompt(audio_text, history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
-    response = ""
-
-    for response_token in stream:
-        response += response_token.token.text
-
-    response = ' '.join(response.split()).replace('</s>', '')
-    audio_file = text_to_speech(response, speed=1.3)
-    return response, audio_file
-
-def text_to_speech(text, speed=1.3):
-    tts = gTTS(text=text, lang='es')
-    audio_fp = io.BytesIO()
-    tts.write_to_fp(audio_fp)
-    audio_fp.seek(0)
-    audio = AudioSegment.from_file(audio_fp, format="mp3")
-    modified_speed_audio = audio.speedup(playback_speed=speed)
-    modified_audio_fp = io.BytesIO()
-    modified_speed_audio.export(modified_audio_fp, format="mp3")
-    modified_audio_fp.seek(0)
-    return modified_audio_fp
-
-def main():
-    audio_data = audiorecorder("Habla para grabar", "Deteniendo la grabación...")
-
-    if not audio_data.empty():
-        st.audio(audio_data.export().read(), format="audio/wav")
-        audio_data.export("audio.wav", format="wav")
-        audio_text = recognize_speech("audio.wav")
-
-    if not st.session_state.history:
-        pre_prompt = "Te Llamarás Chaman 4.0 y tus respuestas serán sumamente breves."
-        output, _ = generate(pre_prompt, history=st.session_state.history)
-        st.session_state.history.append((pre_prompt, output))
-
-    if audio_text:
-        output, audio_file = generate(audio_text, history=st.session_state.history)
-
-    if audio_text:
-        st.session_state.history.append((audio_text, output))
-
-    if audio_file is not None:
-        st.markdown(
-            f"""
-            <audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_file.read()).decode()}" type="audio/mp3" id="audio_player"></audio>
-            """,
-            unsafe_allow_html=True
-        )
-
-if __name__ == "__main__":
-    main()
+    formatted_prompt = format_prompt(prompt, history)
+
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+        yield output
+
+    return output
+
+chat_interface = gr.ChatInterface(
+    fn=generate,
+    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=False, likeable=False, layout="vertical", height=700),
+    concurrency_limit=9,
+    theme="soft",
+    retry_btn=None,
+    undo_btn=None,
+    clear_btn=None,
+    submit_btn="Enviar",
+)
+
+chat_interface.launch(show_api=False)
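Because generate now yields partial outputs, gr.ChatInterface streams the reply token by token into the chat window instead of waiting for the full response. A minimal way to exercise the generator outside the UI, assuming the definitions above are already in scope:

# Illustrative sketch only, not part of this commit: consume the generator
# directly. Each iteration receives the response accumulated so far, which is
# what lets gr.ChatInterface render the reply incrementally.
for partial in generate("Hola, ¿quién eres?", history=[]):
    print(partial)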