kathirog committed on
Commit
f404bdb
·
verified ·
1 Parent(s): ab37b41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -78
app.py CHANGED
@@ -4,83 +4,98 @@ from threading import Thread
4
 
5
  import gradio as gr
6
  import torch
7
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
8
- from accelerate import Accelerator # Use Accelerate for better performance
9
-
10
- # Argument parsing (optional, can be omitted if not required)
11
- parser = argparse.ArgumentParser(prog="SOCRATIC-CHATBOT", description="Socratic chatbot")
12
-
13
- parser.add_argument("--load-in-4bit",
14
- action="store_true",
15
- help="Load base model with 4bit quantization (requires GPU)")
16
-
17
- parser.add_argument("--server-port",
18
- type=int,
19
- default=2121,
20
- help="The port the chatbot server listens to")
21
-
22
- args = parser.parse_args()
23
-
24
- # Accelerator setup to manage devices efficiently (CPU/GPU)
25
- accelerator = Accelerator()
26
-
27
- with gr.Blocks() as demo:
28
- chatbot = gr.Chatbot()
29
- msg = gr.Textbox()
30
- clear = gr.Button("Clear")
31
-
32
- # Load prompt template from external file
33
- with urllib.request.urlopen(
34
- "https://raw.githubusercontent.com/GiovanniGatti/socratic-llm/kdd-2024/templates/inference.txt"
35
- ) as f:
36
- inference_prompt_template = f.read().decode('utf-8')
37
-
38
- # Detect device (GPU if available)
39
- device = accelerator.device
40
-
41
- # Load model and tokenizer with efficient memory management
42
- model = AutoModelForCausalLM.from_pretrained(
43
- "eurecom-ds/Phi-3-mini-4k-socratic",
44
- torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32, # Use bfloat16 on GPU
45
- load_in_4bit=args.load_in_4bit,
46
- trust_remote_code=True,
47
- device_map="auto", # Automatically distribute model to available devices
48
- ).to(device)
49
-
50
- tokenizer = AutoTokenizer.from_pretrained("eurecom-ds/Phi-3-mini-4k-socratic")
51
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
52
-
53
- # Function to handle user messages
54
- def user(user_message, history):
55
- return "", history + [[user_message, ""]]
56
-
57
- # Function to generate bot responses
58
- def bot(history):
59
- user_query = "".join(f"Student: {s}\nTeacher: {t}\n" for s, t in history[:-1])
60
- last_query = history[-1][0]
61
- user_query += f"Student: {last_query}"
62
- content = inference_prompt_template.format(input=user_query)
63
-
64
- formatted = tokenizer.apply_chat_template(
65
- [{"role": "user", "content": content}], tokenize=False, add_generation_prompt=True
66
  )
67
 
68
- encoded_inputs = tokenizer([formatted], return_tensors="pt").to(device)
69
-
70
- # Use threads to handle model generation asynchronously
71
- thread = Thread(target=model.generate, kwargs=dict(encoded_inputs, max_new_tokens=250, streamer=streamer))
72
- thread.start()
73
-
74
- for word in streamer:
75
- history[-1][1] += word
76
- yield history
77
-
78
- # User interaction handling
79
- msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, [chatbot], chatbot)
80
-
81
- # Clear chat button functionality
82
- clear.click(lambda: None, None, chatbot, queue=False)
83
-
84
- # Launch the Gradio app
85
- demo.queue()
86
- demo.launch(server_name="0.0.0.0", server_port=args.server_port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  import gradio as gr
6
  import torch
7
+ import pyttsx3
8
+ import speech_recognition as sr
9
+ from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM
10
+
11
# Convert voice input (audio) to text.
def voice_to_text(audio):
    """Transcribe a recorded audio file using Google Speech Recognition.

    Parameters
    ----------
    audio : str | file-like
        Path to the audio file. ``gr.Audio(type="filepath")`` passes a plain
        string path; a file-like object exposing ``.name`` is also accepted.

    Returns
    -------
    str
        The transcription, or a human-readable error message when the audio
        could not be understood or the recognition service was unreachable.
    """
    # BUG FIX: with gr.Audio(type="filepath") the callback receives a *string*,
    # so the original `audio.name` raised AttributeError. Accept both forms.
    path = audio if isinstance(audio, str) else audio.name
    recognizer = sr.Recognizer()
    with sr.AudioFile(path) as source:
        audio_data = recognizer.record(source)
    try:
        # Free Google Web Speech API endpoint (network call, no API key).
        text = recognizer.recognize_google(audio_data)
    except sr.UnknownValueError:
        text = "Sorry, I could not understand the audio."
    except sr.RequestError:
        text = "Could not request results from Google Speech Recognition service."
    return text
23
+
24
# Convert text to speech (voice output).
def text_to_voice(text):
    """Synthesize *text* with the local TTS engine and return the audio path.

    Parameters
    ----------
    text : str
        The text to speak.

    Returns
    -------
    str
        Path to a freshly created audio file containing the synthesized speech.
    """
    # Local imports keep this helper self-contained.
    import os
    import tempfile

    engine = pyttsx3.init()
    # FIX: pyttsx3 writes platform-native audio (WAV/AIFF) regardless of the
    # file extension, so the original hard-coded 'response.mp3' was mislabeled
    # AND raced between concurrent requests. Use a unique temp file instead.
    fd, path = tempfile.mkstemp(prefix="response_", suffix=".wav")
    os.close(fd)  # pyttsx3 opens the path itself; release our handle.
    engine.save_to_file(text, path)
    engine.runAndWait()  # blocks until synthesis completes
    return path
30
+
31
# Model loading, Gradio UI wiring, and server startup.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog="SOCRATIC-CHATBOT", description="Socratic chatbot")

    parser.add_argument("--load-in-4bit",
                        action="store_true",
                        help="Load base model with 4bit quantization (requires GPU)")

    parser.add_argument("--server-port",
                        type=int,
                        default=2121,
                        help="The port the chatbot server listens to")

    args = parser.parse_args()

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        audio_input = gr.Audio(type="filepath", label="Audio Input (or leave blank to use text input)")
        # FIX: declare the voice-reply component once in the layout instead of
        # instantiating a throwaway gr.Audio() inside the outputs list below.
        audio_output = gr.Audio(label="Voice reply")
        clear = gr.Button("Clear")

        # Prompt template is fetched from the project repository at startup.
        with urllib.request.urlopen(
                "https://raw.githubusercontent.com/GiovanniGatti/socratic-llm/kdd-2024/templates/inference.txt"
        ) as f:
            inference_prompt_template = f.read().decode('utf-8')

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        model = AutoModelForCausalLM.from_pretrained(
            "eurecom-ds/Phi-3-mini-4k-socratic",
            # FIX: bfloat16 is unsupported/very slow on most CPUs — restore the
            # float32 fallback that the previous revision had.
            torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
            load_in_4bit=args.load_in_4bit,
            trust_remote_code=True,
            device_map="auto",  # let accelerate place the weights
        )

        tokenizer = AutoTokenizer.from_pretrained("eurecom-ds/Phi-3-mini-4k-socratic")

        def user(user_message, history):
            # Append the new student turn with an empty teacher slot and clear
            # the textbox.
            return "", history + [[user_message, ""]]

        def bot(history, audio=None):
            """Stream the model's reply; yields (history, audio_path_or_None)."""
            if audio:
                # Voice input: transcribe it and record the transcription as
                # the latest student turn so the UI shows what was understood
                # (the original dropped it and bypassed the transcript format).
                history[-1][0] = voice_to_text(audio)

            # Build the Student/Teacher transcript the prompt template expects.
            past_turns = "".join(f"Student: {s}\nTeacher: {t}\n" for s, t in history[:-1])
            user_query = past_turns + f"Student: {history[-1][0]}"

            content = inference_prompt_template.format(input=user_query)
            formatted = tokenizer.apply_chat_template(
                [{"role": "user", "content": content}], tokenize=False, add_generation_prompt=True
            )

            encoded_inputs = tokenizer([formatted], return_tensors="pt").to(model.device)

            # FIX: a module-level streamer shared by all requests interleaves
            # concurrent generations — create a fresh one per call.
            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            # Generate on a worker thread so we can stream tokens as they come.
            thread = Thread(target=model.generate,
                            kwargs=dict(encoded_inputs, max_new_tokens=250, streamer=streamer))
            thread.start()

            for word in streamer:
                history[-1][1] += word
                # FIX: the original synthesized speech for EVERY streamed token,
                # re-running TTS hundreds of times per reply. Stream text only…
                yield history, None
            # …and voice the complete reply exactly once at the end.
            yield history, text_to_voice(history[-1][1])

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, [chatbot, audio_input], [chatbot, audio_output]
        )
        clear.click(lambda: None, None, chatbot, queue=False)

    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=args.server_port)