VuAI committed on
Commit
f655cc3
·
verified ·
1 Parent(s): bb38b2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -201
app.py CHANGED
@@ -1,211 +1,87 @@
1
- ##########################################
2
- # Step 0: Import required libraries
3
- ##########################################
4
- import streamlit as st # For web interface
5
  from transformers import (
6
- pipeline, # For loading pre-trained models
7
- SpeechT5Processor, # For text-to-speech processing
8
- SpeechT5ForTextToSpeech, # TTS model
9
- SpeechT5HifiGan, # Vocoder for generating audio waveforms
10
- AutoModelForCausalLM, # For text generation
11
- AutoTokenizer # For tokenizing input text
12
- ) # AI model components
13
-
14
- from datasets import load_dataset # To load voice embeddings
15
- import torch # For tensor computations
16
- import soundfile as sf # For handling audio files
17
- import re # For regular expressions in text processing
18
-
19
- ##########################################
20
- # Initial configuration
21
- ##########################################
22
- st.set_page_config(
23
- page_title="Just Comment", # Title of the web app
24
- page_icon="💬", # Icon displayed in the browser tab
25
- layout="centered", # Center the layout of the app
26
- initial_sidebar_state="collapsed" # Start with sidebar collapsed
27
  )
 
28
 
29
- ##########################################
30
- # Global model loading with caching
31
- ##########################################
32
- @st.cache_resource(show_spinner=False) # Cache the models for performance
33
- def _load_models():
34
- """Load and cache all ML models with optimized settings"""
35
- return {
36
- # Emotion classification pipeline
37
- 'emotion': pipeline(
38
- "text-classification", # Specify task type
39
- model="Thea231/jhartmann_emotion_finetuning", # Load the model
40
- truncation=True # Enable text truncation for long inputs
41
- ),
42
-
43
- # Text generation components
44
- 'textgen_tokenizer': AutoTokenizer.from_pretrained(
45
- "Qwen/Qwen1.5-0.5B", # Load tokenizer
46
- use_fast=True # Enable fast tokenization
47
- ),
48
- 'textgen_model': AutoModelForCausalLM.from_pretrained(
49
- "Qwen/Qwen1.5-0.5B", # Load text generation model
50
- torch_dtype=torch.float16 # Use half-precision for faster inference
51
- ),
52
-
53
- # Text-to-speech components
54
- 'tts_processor': SpeechT5Processor.from_pretrained("microsoft/speecht5_tts"), # Load TTS processor
55
- 'tts_model': SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts"), # Load TTS model
56
- 'tts_vocoder': SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan"), # Load vocoder
57
-
58
- # Preloaded speaker embeddings
59
- 'speaker_embeddings': torch.tensor(
60
- load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"] # Load speaker embeddings
61
- ).unsqueeze(0) # Add an additional dimension for batch processing
62
- }
63
-
64
- ##########################################
65
- # UI Components
66
- ##########################################
67
- def _display_interface():
68
- """Render user interface elements"""
69
- st.title("Just Comment") # Set the main title of the app
70
- st.markdown("### I'm listening to you, my friend~") # Subheading for user interaction
71
-
72
- return st.text_area(
73
- "📝 Enter your comment:", # Label for the text area
74
- placeholder="Type your message here...", # Placeholder text
75
- height=150, # Height of the text area
76
- key="user_input" # Unique key for the text area
77
- )
78
-
79
- ##########################################
80
- # Core Processing Functions
81
- ##########################################
82
- def _analyze_emotion(text, classifier):
83
- """Identify dominant emotion with confidence threshold"""
84
- results = classifier(text, return_all_scores=True)[0] # Get emotion scores
85
- valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'} # Define valid emotions
86
- filtered = [e for e in results if e['label'].lower() in valid_emotions] # Filter results by valid emotions
87
- return max(filtered, key=lambda x: x['score']) # Return the emotion with the highest score
88
 
89
- def _generate_prompt(text, emotion):
90
- """Create structured prompts for all emotion types"""
91
- prompt_templates = {
92
- "sadness": (
93
- "Sadness detected: {input}\n"
94
- "Required response structure:\n"
95
- "1. Empathetic acknowledgment\n2. Support offer\n3. Solution proposal\n"
96
- "Response:"
97
- ),
98
- "joy": (
99
- "Joy detected: {input}\n"
100
- "Required response structure:\n"
101
- "1. Enthusiastic thanks\n2. Positive reinforcement\n3. Future engagement\n"
102
- "Response:"
103
- ),
104
- "love": (
105
- "Affection detected: {input}\n"
106
- "Required response structure:\n"
107
- "1. Warm appreciation\n2. Community focus\n3. Exclusive benefit\n"
108
- "Response:"
109
- ),
110
- "anger": (
111
- "Anger detected: {input}\n"
112
- "Required response structure:\n"
113
- "1. Sincere apology\n2. Action steps\n3. Compensation\n"
114
- "Response:"
115
- ),
116
- "fear": (
117
- "Concern detected: {input}\n"
118
- "Required response structure:\n"
119
- "1. Reassurance\n2. Safety measures\n3. Support options\n"
120
- "Response:"
121
- ),
122
- "surprise": (
123
- "Surprise detected: {input}\n"
124
- "Required response structure:\n"
125
- "1. Acknowledge uniqueness\n2. Creative solution\n3. Follow-up\n"
126
- "Response:"
127
- )
128
- }
129
- return prompt_templates.get(emotion.lower(), "").format(input=text) # Format and return the appropriate prompt
130
 
131
- def _process_response(raw_text):
132
- """Clean and format the generated response"""
133
- # Extract text after last "Response:" marker
134
- processed = raw_text.split("Response:")[-1].strip()
135
-
136
- # Remove incomplete sentences
137
- if '.' in processed:
138
- processed = processed.rsplit('.', 1)[0] + '.' # Ensure the response ends with a period
139
-
140
- # Ensure length between 50-200 characters
141
- return processed[:200].strip() if len(processed) > 50 else "Thank you for your feedback. We value your input and will respond shortly."
142
 
143
- def _generate_text_response(input_text, models):
144
- """Generate optimized text response with timing controls"""
145
- # Emotion analysis
146
- emotion = _analyze_emotion(input_text, models['emotion']) # Analyze the emotion of user input
147
-
148
- # Prompt engineering
149
- prompt = _generate_prompt(input_text, emotion['label']) # Generate prompt based on detected emotion
150
-
151
- # Text generation with optimized parameters
152
- inputs = models['textgen_tokenizer'](prompt, return_tensors="pt").to('cpu') # Tokenize the prompt
153
- outputs = models['textgen_model'].generate(
154
- inputs.input_ids, # Input token IDs
155
- max_new_tokens=100, # Strict token limit for response length
156
- temperature=0.7, # Control randomness in text generation
157
- top_p=0.9, # Control diversity in sampling
158
- do_sample=True, # Enable sampling to generate varied responses
159
- pad_token_id=models['textgen_tokenizer'].eos_token_id # Use end-of-sequence token for padding
160
- )
161
-
162
- return _process_response(
163
- models['textgen_tokenizer'].decode(outputs[0], skip_special_tokens=True) # Decode and process the response
164
  )
 
 
 
 
 
165
 
166
- def _generate_audio_response(text, models):
167
- """Convert text to speech with performance optimizations"""
168
- # Process text input for TTS
169
- inputs = models['tts_processor'](text=text, return_tensors="pt") # Tokenize input text for TTS
170
-
171
- # Generate spectrogram
172
- spectrogram = models['tts_model'].generate_speech(
173
- inputs["input_ids"], # Input token IDs for TTS
174
- models['speaker_embeddings'] # Use preloaded speaker embeddings
175
- )
176
-
177
- # Generate waveform with optimizations
178
- with torch.no_grad(): # Disable gradient calculation for inference
179
- waveform = models['tts_vocoder'](spectrogram) # Generate audio waveform from spectrogram
180
-
181
- # Save audio file
182
- sf.write("response.wav", waveform.numpy(), samplerate=16000) # Save waveform as a WAV file
183
- return "response.wav" # Return the path to the saved audio file
184
 
185
- ##########################################
186
- # Main Application Flow
187
- ##########################################
188
- def main():
189
- """Primary execution flow"""
190
- # Load models once
191
- ml_models = _load_models() # Load all models and cache them
192
-
193
- # Display interface
194
- user_input = _display_interface() # Show the user input interface
195
-
196
- if user_input: # Check if user has entered input
197
- # Text generation stage
198
- with st.spinner("🔍 Analyzing emotions and generating response..."): # Show loading spinner
199
- text_response = _generate_text_response(user_input, ml_models) # Generate text response
200
-
201
- # Display results
202
- st.subheader("📄 Generated Response") # Subheader for response section
203
- st.markdown(f"```\n{text_response}\n```") # Display generated response in markdown format
204
-
205
- # Audio generation stage
206
- with st.spinner("🔊 Converting to speech..."): # Show loading spinner
207
- audio_file = _generate_audio_response(text_response, ml_models) # Generate audio response
208
- st.audio(audio_file, format="audio/wav") # Play audio file in the app
209
 
210
- if __name__ == "__main__":
211
- main() # Execute the main function when the script is run
 
1
import gradio as gr
import torch
import soundfile as sf
from transformers import (
    pipeline,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    AutoModelForCausalLM,
    AutoTokenizer
)
from datasets import load_dataset

# Load all models globally at import time (could be switched to lazy-loading).
# Emotion classifier: scores a comment against the six emotion labels of the
# fine-tuned model referenced below.
emotion_classifier = pipeline(
    "text-classification",
    model="Thea231/jhartmann_emotion_finetuning",
    truncation=True  # truncate over-long inputs instead of erroring
)

# Small causal LM used to draft the reply text.
textgen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B", use_fast=True)
# NOTE(review): float16 weights with CPU inference can be slow or unsupported
# on some builds — confirm the target hardware before relying on this dtype.
textgen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B", torch_dtype=torch.float16)

# SpeechT5 text-to-speech stack: processor -> acoustic model -> HiFi-GAN vocoder.
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
tts_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Fixed speaker identity (x-vector #7306 from the CMU ARCTIC x-vector dataset);
# unsqueeze(0) adds the batch dimension expected by generate_speech.
speaker_embeddings = torch.tensor(
    load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
# Emotion prompt templates.
# Maps each supported (lowercase) emotion label to a prompt scaffold; {input}
# is replaced with the user's comment and the LM continues after "Response:".
PROMPT_TEMPLATES = {
    "sadness": (
        "Sadness detected: {input}\n"
        "1. Empathetic acknowledgment\n"
        "2. Support offer\n"
        "3. Solution proposal\n"
        "Response:"
    ),
    "joy": (
        "Joy detected: {input}\n"
        "1. Enthusiastic thanks\n"
        "2. Positive reinforcement\n"
        "3. Future engagement\n"
        "Response:"
    ),
    "love": (
        "Affection detected: {input}\n"
        "1. Warm appreciation\n"
        "2. Community focus\n"
        "3. Exclusive benefit\n"
        "Response:"
    ),
    "anger": (
        "Anger detected: {input}\n"
        "1. Sincere apology\n"
        "2. Action steps\n"
        "3. Compensation\n"
        "Response:"
    ),
    "fear": (
        "Concern detected: {input}\n"
        "1. Reassurance\n"
        "2. Safety measures\n"
        "3. Support options\n"
        "Response:"
    ),
    "surprise": (
        "Surprise detected: {input}\n"
        "1. Acknowledge uniqueness\n"
        "2. Creative solution\n"
        "3. Follow-up\n"
        "Response:"
    ),
}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
def analyze_emotion(text):
    """Return the lowercase label of the strongest supported emotion in *text*.

    Only labels present in PROMPT_TEMPLATES are considered; raises ValueError
    if the classifier yields none of the supported labels.
    """
    raw_scores = emotion_classifier(text, return_all_scores=True)[0]
    supported = (item for item in raw_scores if item["label"].lower() in PROMPT_TEMPLATES)
    top = max(supported, key=lambda item: item["score"])
    return top["label"].lower()
 
 
 
 
 
 
 
43
 
44
def generate_response(comment):
    """Draft an emotion-aware text reply for *comment* using the Qwen LM.

    Returns either the trimmed model output (capped at 200 characters) or a
    fixed fallback message when the output is too short to be useful.
    """
    detected = analyze_emotion(comment)
    prompt = PROMPT_TEMPLATES[detected].format(input=comment)
    encoded = textgen_tokenizer(prompt, return_tensors="pt").to("cpu")
    generated = textgen_model.generate(
        encoded.input_ids,
        max_new_tokens=100,  # hard cap on reply length
        temperature=0.7,     # moderate randomness
        top_p=0.9,           # nucleus sampling
        do_sample=True,
        pad_token_id=textgen_tokenizer.eos_token_id,
    )
    decoded = textgen_tokenizer.decode(generated[0], skip_special_tokens=True)
    # Keep only the text after the final "Response:" marker (the prompt is echoed).
    reply = decoded.split("Response:")[-1].strip()
    # Drop a trailing incomplete sentence, if any.
    if "." in reply:
        reply = reply.rsplit(".", 1)[0] + "."
    if len(reply) > 50:
        return reply[:200]
    # Fallback (Vietnamese): "Thank you for your feedback. We will review it carefully."
    return "Cảm ơn bạn đã phản hồi. Chúng tôi sẽ xem xét kỹ lưỡng."
61
 
62
def generate_audio(text):
    """Synthesize *text* to speech and save it as response.wav; returns the path."""
    tokenized = tts_processor(text=text, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradients needed
        spectrogram = tts_model.generate_speech(tokenized["input_ids"], speaker_embeddings)
        audio = tts_vocoder(spectrogram)
    # SpeechT5 produces 16 kHz audio.
    sf.write("response.wav", audio.numpy(), 16000)
    return "response.wav"
 
 
 
 
 
 
 
 
 
 
 
69
 
70
def full_pipeline(comment):
    """End-to-end handler: comment -> (text reply, path to the spoken reply)."""
    reply = generate_response(comment)
    return reply, generate_audio(reply)
74
+
75
# Gradio UI: one text input mapped to (AI reply text, playable audio file).
demo = gr.Interface(
    fn=full_pipeline,
    inputs=gr.Textbox(
        label="💬 Nhập bình luận",
        placeholder="Ví dụ: Sản phẩm này có bền không vậy?",
    ),
    outputs=[
        gr.Textbox(label="📄 Phản hồi AI"),
        gr.Audio(label="🔊 Phát lại", type="filepath"),
    ],
    title="Just Comment 🐠 (Gradio Edition)",
    description="Phân tích cảm xúc + phản hồi AI + chuyển thành giọng nói",
)

demo.launch()