Dddrl committed (verified)
Commit f2da24d · Parent(s): e273214

Upload 2 files

Files changed (1): app.py (+89 -51)
app.py CHANGED
@@ -6,98 +6,134 @@ import numpy as np
 import gradio as gr
 import openai
 import os
+from transformers import Wav2Vec2FeatureExtractor
+from transformers import Wav2Vec2Model
 
-# Emotion categories
-emotions = ["Neutral", "Happy", "Angry", "Sad", "Surprise"]
+# ----------------- Setup ---------------------
 
-# CNN model definition
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+wav2vec2_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(device)
+# Load Wav2Vec2 feature extractor
+model_name = "facebook/wav2vec2-base"
+feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+
+# --------------- Load Emotion Classification Model -----------------
 class CNN(nn.Module):
     def __init__(self, num_classes):
         super(CNN, self).__init__()
         self.name = "CNN"
+
         self.conv1 = nn.Conv1d(in_channels=768, out_channels=256, kernel_size=3, padding=1)
         self.bn1 = nn.BatchNorm1d(256)
         self.pool = nn.AdaptiveMaxPool1d(output_size=96)
+
         self.conv2 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
         self.bn2 = nn.BatchNorm1d(128)
+
         self.conv3 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
         self.bn3 = nn.BatchNorm1d(64)
+
         self.fc1 = nn.Linear(64 * 96, 128)
         self.dropout = nn.Dropout(0.5)
         self.fc2 = nn.Linear(128, num_classes)
 
     def forward(self, x):
-        x = x.unsqueeze(1)
+        # x = x.unsqueeze(1)
         x = x.permute(0, 2, 1)
+
         x = F.relu(self.bn1(self.conv1(x)))
+        #print(f"Before pooling 1, x shape: {x.shape}")
         x = self.pool(x)
+        #print(f"After pooling 1, x shape: {x.shape}")
+
         x = F.relu(self.bn2(self.conv2(x)))
+        #print(f"Before pooling 2, x shape: {x.shape}")
         x = self.pool(x)
+        #print(f"After pooling 2, x shape: {x.shape}")
+
         x = F.relu(self.bn3(self.conv3(x)))
+        #print(f"Before pooling 3, x shape: {x.shape}")
         x = self.pool(x)
+        #print(f"After pooling 3, x shape: {x.shape}")
+
         x = x.view(x.size(0), -1)
         x = F.relu(self.fc1(x))
         x = self.dropout(x)
         x = self.fc2(x)
+
         return x
 
-# Load the trained model
-model = CNN(num_classes=5)
-model.load_state_dict(torch.load("best_model_CNN_bs32_lr0.0005_epoch9_acc0.9238.pth", map_location="cpu"))
+model = CNN(5)
+model.load_state_dict(torch.load("best_model_CNN_bs32_lr0.0005_epoch9_acc0.9238.pth", map_location=torch.device("cpu")))
 model.eval()
+wav2vec2_model.eval()
 
-# Extract features from audio file
-def extract_feature(audio_path):
-    y, sr = librosa.load(audio_path, sr=16000)
-    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=40)
-    max_len = 200
-    if mfcc.shape[1] > max_len:
-        mfcc = mfcc[:, :max_len]
-    else:
-        pad_width = max_len - mfcc.shape[1]
-        mfcc = np.pad(mfcc, ((0, 0), (0, pad_width)), mode='constant')
-    feature = np.tile(mfcc, (int(768 / 40), 1))
-    feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0)
-    return feature
-
-# Full pipeline: emotion detection + GPT response
-def predict_and_reply(audio_path):
-    model.eval()
-
-    # Load and preprocess audio
-    feature = extract_feature(audio_path)
-
-    # Move model and input to correct device
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
-    feature = feature.to(device)
-
-    # Predict
-    with torch.no_grad():
-        output = model(feature)
-        pred = torch.argmax(output, dim=1).item()
-        emotion = emotions[pred]
-
-    prompt = f"The user sounds {emotion.lower()}. What would you like to say to them?"
+label_map = {0: "Neutral", 1: "Happy", 2: "Angry", 3: "Sad", 4: "Surprise"}
+
+# ------------------ ChatGPT API Setup ---------------------
+openai.api_key = ""  # Use env variable or secret manager in production!
+
+def create_prompt_from_label(label):
+    return f"""
+    The user is currently feeling {label.lower()}. Start by briefly and thoughtfully acknowledging how someone might feel when experiencing this emotion.
+
+    Then, as a recommendation system, suggest 3 pieces of entertainment content—such as movies, music, or shows—that align with or help support this mood.
+
+    Ensure your tone is friendly and supportive, and make the recommendations short, engaging, and tailored to the {label.lower()} emotional state.
+
+    You can add some lovely emoji to keep the tone warm.
+    """
+
+def get_recommendations(label):
+    prompt = create_prompt_from_label(label)
     try:
-        openai.api_key = os.getenv("OPENAI_API_KEY", "sk-proj-[REDACTED]")  # Replace with real key or env var
         response = openai.ChatCompletion.create(
-            model="gpt-3.5-turbo",
+            model="gpt-4",
             messages=[
                 {"role": "system", "content": "You are a helpful assistant that provides entertainment recommendations."},
                 {"role": "user", "content": prompt}
-            ]
+            ],
+            max_tokens=500,
+            temperature=0.7
         )
-        reply = response['choices'][0]['message']['content']
+        return response['choices'][0]['message']['content'].strip()
     except Exception as e:
-        reply = f" GPT Error: {str(e)}"
+        return f"An error occurred: {e}"
 
-    return f"🎧 Detected Emotion: **{emotion}**\n\n💬 GPT Says:\n{reply}"
-
-# Gradio app layout
-with gr.Blocks(theme=gr.themes.Soft()) as demo:
+# ----------------- Inference Pipeline ---------------------
+def process_audio_and_recommend(file_path):
+    audio, sr = librosa.load(file_path, sr=16000)
+    max_duration = 5
+    max_samples = int(max_duration * sr)
+    if len(audio) > max_samples:
+        audio = audio[:max_samples]
+
+    inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
+    input_values = inputs["input_values"].to(device)
+
+    with torch.no_grad():
+        # Get real Wav2Vec2 embeddings
+        features = wav2vec2_model(input_values).last_hidden_state  # shape: [1, seq_len, 768]
+        outputs = model(features)  # pass embeddings directly, no extra dim needed
+
+    pred_idx = torch.argmax(outputs, dim=1).item()
+    emotion = label_map[pred_idx]
+    recommendations = get_recommendations(emotion)
+    return f"🧠 Detected Emotion: {emotion}", recommendations
+
+
+# ----------------- Gradio UI ---------------------
+# interface = gr.Interface(
+#     fn=process_audio_and_recommend,
+#     inputs=gr.Audio(type="filepath"),
+#     outputs=["text", "text"],
+#     title="🎙️ Emotion-Based Entertainment Bot",
+#     description="Upload your voice. We'll detect your emotion and ChatGPT will suggest entertainment!"
+# )
+
+# interface.launch()
+
+with gr.Blocks(theme=gr.themes.Soft()) as interface:
     gr.Markdown("## 🎙️ Emotion Detection + Chatbot")
     gr.Markdown("Upload or record a short voice clip; I will detect your emotion and ask GPT for an empathetic reply.")
 
@@ -106,8 +142,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         audio_input = gr.Audio(label="🎧 Voice Input", type="filepath", format="wav")
         submit_btn = gr.Button("🚀 Submit")
         with gr.Column():
-            output_text = gr.Markdown(label="💬 GPT Reply")
+            output_text_1 = gr.Text(label="🧠 Detected Emotion")
+            output_text_2 = gr.Text(label="💬 GPT Reply")
 
-    submit_btn.click(fn=predict_and_reply, inputs=audio_input, outputs=output_text)
+    submit_btn.click(fn=process_audio_and_recommend, inputs=audio_input, outputs=[output_text_1, output_text_2])
 
-demo.launch()
+interface.launch()
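
With this change the CNN consumes Wav2Vec2 hidden states of shape [batch, seq_len, 768] directly: forward permutes to channels-first itself, and the old x.unsqueeze(1) is commented out. A minimal shape check of that contract, assuming the CNN class from the diff above (the 249-frame length is roughly what wav2vec2-base emits for a 5-second clip at 16 kHz):

    import torch

    cnn = CNN(num_classes=5)
    cnn.eval()
    dummy = torch.randn(1, 249, 768)   # [batch, frames, hidden], like last_hidden_state
    with torch.no_grad():
        logits = cnn(dummy)            # permute -> [1, 768, 249] -> conv/pool stack -> flatten 64*96 -> fc
    print(logits.shape)                # torch.Size([1, 5])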
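The conv stack tolerates variable-length clips because nn.AdaptiveMaxPool1d(output_size=96) maps any input length to exactly 96 frames, which is what makes the fixed nn.Linear(64 * 96, 128) head safe. A small demonstration, assuming nothing beyond torch:

    import torch
    import torch.nn as nn

    pool = nn.AdaptiveMaxPool1d(output_size=96)
    for frames in (100, 249, 500):          # different clip lengths
        out = pool(torch.randn(1, 64, frames))
        print(out.shape)                    # always torch.Size([1, 64, 96])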
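The hard-coded key is gone in favour of the openai.api_key = "" placeholder; note that openai.ChatCompletion.create is the legacy pre-1.0 interface of the openai package, so the Space needs openai<1.0 pinned. A hedged sketch of supplying the key via the environment instead (the OPENAI_API_KEY variable name and the Space-secret setup are assumptions, not part of this commit):

    import os
    import openai

    # Read the key from an environment variable (e.g. a Hugging Face Space secret);
    # never commit the key itself to the repo.
    openai.api_key = os.environ.get("OPENAI_API_KEY", "")
    if not openai.api_key:
        raise RuntimeError("Set OPENAI_API_KEY before launching the app.")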
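A quick end-to-end smoke test of the new pipeline, assuming the models above are loaded, a key is set, and a local recording exists (sample.wav is a hypothetical placeholder path):

    # Hypothetical smoke test; requires the models loaded above and a valid API key.
    emotion_text, recommendations = process_audio_and_recommend("sample.wav")
    print(emotion_text)        # e.g. "🧠 Detected Emotion: Happy"
    print(recommendations)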