sudo-paras-shah committed
Commit 1ef230c · 1 Parent(s): 0314565

Add streamlit home to environment

Remove classification file, do everything in one file

Hopium Part 14

Files changed (2)
  1. src/classification.py +0 -124
  2. src/streamlit_app.py +229 -162
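
Note on the headline change ("Add streamlit home to environment"): Streamlit creates ~/.streamlit at startup, and on hosted containers such as Hugging Face Spaces the default home directory is often not writable, so the new file points both variables at /tmp. The pattern, as it appears in the diff below (the why-comments are our reading, not part of the commit):

    import os

    # Streamlit resolves ~/.streamlit via HOME; /tmp is writable in most container runtimes.
    os.environ["HOME"] = "/tmp"
    os.environ["STREAMLIT_HOME"] = "/tmp"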
src/classification.py DELETED
@@ -1,124 +0,0 @@
- import os
- import tempfile
-
- import matplotlib.pyplot as plt
- import numpy as np
-
- from nets import get_model_from_name
- from utils.utils import (cvtColor, get_classes, letterbox_image,
-                          preprocess_input)
-
- from huggingface_hub import hf_hub_download
-
- cache_dir = os.path.join(tempfile.gettempdir(), "hf_cache")
- os.makedirs(cache_dir, exist_ok=True)
-
- #--------------------------------------------#
- #   To predict with your own trained model, 4 parameters must be changed:
- #   model_path, classes_path, backbone
- #   and alpha all need to be modified!
- #--------------------------------------------#
- class Classification(object):
-     _defaults = {
-         #--------------------------------------------------------------------------#
-         #   To predict with your own trained model, be sure to modify model_path and classes_path!
-         #   model_path points to the weights file under logs; classes_path points to the txt under model_data.
-         #   If a shape mismatch occurs, also check that model_path and classes_path match the training settings.
-         #--------------------------------------------------------------------------#
-         # "model_path"  : 'model_data/mobilenet_2_5_224_tf_no_top.h5',
-         "model_path"    : hf_hub_download(repo_id="sudo-paras-shah/micro-expression-casme2", filename="ep089.weights.h5", cache_dir=cache_dir),
-         "classes_path"  : 'src/model_data/cls_classes.txt',
-         #--------------------------------------------------------------------#
-         #   Input image size
-         #--------------------------------------------------------------------#
-         "input_shape"   : [224, 224],
-         #--------------------------------------------------------------------#
-         #   Backbone type:
-         #   mobilenet, resnet50 and vgg16 are common classification networks
-         #--------------------------------------------------------------------#
-         "backbone"      : 'vgg16',
-         #--------------------------------------------------------------------#
-         #   The alpha value used with mobilenet;
-         #   only effective when backbone='mobilenet'
-         #--------------------------------------------------------------------#
-         "alpha"         : 0.25
-     }
-
-     @classmethod
-     def get_defaults(cls, n):
-         if n in cls._defaults:
-             return cls._defaults[n]
-         else:
-             return "Unrecognized attribute name '" + n + "'"
-
-     #---------------------------------------------------#
-     #   Initialize classification
-     #---------------------------------------------------#
-     def __init__(self, **kwargs):
-         self.__dict__.update(self._defaults)
-         for name, value in kwargs.items():
-             setattr(self, name, value)
-
-         #---------------------------------------------------#
-         #   Get the classes
-         #---------------------------------------------------#
-         self.class_names, self.num_classes = get_classes(self.classes_path)
-         self.generate()
-
-     #---------------------------------------------------#
-     #   Load the model
-     #---------------------------------------------------#
-     def generate(self):
-         model_path = os.path.expanduser(self.model_path)
-         assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
-
-         #---------------------------------------------------#
-         #   Load the model and its weights
-         #---------------------------------------------------#
-         if self.backbone == "mobilenet":
-             self.model = get_model_from_name[self.backbone](input_shape = [self.input_shape[0], self.input_shape[1], 3], classes = self.num_classes, alpha = self.alpha)
-         else:
-             self.model = get_model_from_name[self.backbone](input_shape = [self.input_shape[0], self.input_shape[1], 3], classes = self.num_classes)
-         self.model.load_weights(self.model_path)
-         print('{} model, and classes {} loaded.'.format(model_path, self.class_names))
-
-     #---------------------------------------------------#
-     #   Detect an image
-     #---------------------------------------------------#
-     def detect_image(self, image):
-         #---------------------------------------------------------#
-         #   Convert the image to RGB here so grayscale images do not fail at prediction time.
-         #   The code only supports prediction on RGB images; all other image types are converted to RGB.
-         #---------------------------------------------------------#
-         image = cvtColor(image)
-         # Check the data type
-         # print(type(image))
-         #---------------------------------------------------#
-         #   Resize the image without distortion
-         #---------------------------------------------------#
-         image_data = letterbox_image(image, [self.input_shape[1], self.input_shape[0]])
-         #---------------------------------------------------------#
-         #   Normalize and add the batch_size dimension
-         #---------------------------------------------------------#
-         image_data = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)
-
-         #---------------------------------------------------#
-         #   Feed the image to the network for prediction
-         #---------------------------------------------------#
-         preds = self.model.predict(image_data)[0]
-         #---------------------------------------------------#
-         #   Get the predicted class
-         #---------------------------------------------------#
-         class_name = self.class_names[np.argmax(preds)]
-         probability = np.max(preds)
-
-         #---------------------------------------------------#
-         #   Plot and label the result
-         #---------------------------------------------------#
-
-         # plt.subplot(1, 1, 1)
-         # plt.imshow(np.array(image))
-         # plt.title('Class:%s Probability:%.3f' %(class_name, probability))
-         # plt.show()
-
-         return class_name, probability
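
The public interface of the deleted class is unchanged by the merge into streamlit_app.py: detect_image takes a PIL image and returns the top class name with its probability. A minimal usage sketch, assuming the weights, src/model_data/cls_classes.txt, and the nets/utils modules are available ("face.png" is a placeholder input path):

    from PIL import Image

    clf = Classification()  # downloads and loads weights via hf_hub_download
    emotion, prob = clf.detect_image(Image.open("face.png"))  # placeholder image
    print(emotion, round(float(prob), 3))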
src/streamlit_app.py CHANGED
@@ -1,5 +1,13 @@
  import os
- import subprocess

  import cv2
  import numpy as np
@@ -8,170 +16,229 @@ from PIL import Image
  import streamlit as st
  from streamlit_webrtc import VideoProcessorBase, webrtc_streamer

- from classification import Classification
-
- @st.cache_resource
- def get_model():
-     return Classification
-
- classificator = get_model()
- face_cascade = cv2.CascadeClassifier(
-     os.path.join('src', 'model_data', 'haarcascade_frontalface_alt.xml')
- )
-
- # Streamlit Title
- st.title("Real-Time Micro-Emotion Recognition")

- # Only Live Emotion Detection Mode
- st.write("Turn on your camera and detect emotions in real-time.")
-
- # Camera selection UI
- st.sidebar.header("Camera Settings")
- def get_connected_cameras():
      try:
-         result = subprocess.run(
-             ['v4l2-ctl', '--list-devices'],
-             capture_output=True,
-             text=True,
-             check=True)
-         devices = result.stdout.split('\n\n')
-         camera_indices = []
-         for device in devices:
-             if "Camera" in device or "camera" in device:
-                 lines = device.split('\n')
-                 if len(lines) > 1:
-                     index_line = lines[1]
-                     index_str = index_line.strip().split(':')[0].strip()
-                     try:
-                         index = int(index_str[4:])
-                         camera_indices.append(index)
-                     except (ValueError, IndexError):
-                         pass
-         return camera_indices
-     except FileNotFoundError:
-         return [0]  # Fallback to default camera if v4l2-ctl is not available
-     except subprocess.CalledProcessError:
-         return [0]
-
- available_cameras = get_connected_cameras()
-
- if len(available_cameras) > 1:
-     camera_index = st.sidebar.selectbox(
-         "Select Camera Index",
-         options=available_cameras,
-         index=0,
-         format_func=lambda x: f"Camera {x}"
-     )
- else:
-     camera_index = 0
-     st.sidebar.write("Only one camera detected. Using default camera.")
-
- # --- Face detection and augmentation functions ---
- def face_detect(img):
-     img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
-     faces = face_cascade.detectMultiScale(
-         img_gray,
-         scaleFactor=1.1,
-         minNeighbors=1,
-         minSize=(30, 30)
-     )
-     return img, img_gray, faces
-
- # --- Emotion class mapping ---
- def map_emotion_to_class(emotion):
-     positive = ['happiness', 'happy']
-     negative = ['disgust', 'sadness', 'fear', 'sad', 'angry', 'disgusted']
-     surprise = ['surprise']
-     others = ['repression', 'tense', 'neutral', 'others']
-     e = emotion.lower()
-     if any(p in e for p in positive):
-         return 'Positive'
-     elif any(n in e for n in negative):
-         return 'Negative'
-     elif any(s in e for s in surprise):
-         return 'Surprise'
-     else:
-         return 'Others'
-
- # --- Streamlit session state for emotion tracking ---
- if 'emotion_history' not in st.session_state:
-     st.session_state['emotion_history'] = []
-
- # Video Processing Class
- class EmotionRecognitionProcessor(VideoProcessorBase):
-     def __init__(self):
-         self.last_class = None
-         self.rapid_change_count = 0
-
-     def recv(self, frame):
-         border_color = (255, 0, 0)  # Rectangle color (blue in BGR)
-         font_color = (0, 0, 255)  # Text color (red in BGR)
-         img = frame.to_ndarray(format="bgr24")
-         img_disp, img_gray, faces = face_detect(img)
-         current_class = None
-
-         if len(faces) == 0:
-             cv2.putText(
-                 img_disp, 'No Face Detect.', (2, 20),
-                 cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1
              )
-
-         for (x, y, w, h) in faces:
-             x1, y1 = max(x - 10, 0), max(y - 10, 0)
-             x2 = min(x + w + 10, img_disp.shape[1])
-             y2 = min(y + h + 10, img_disp.shape[0])
-
-             face_img_gray = img_gray[y1:y2, x1:x2]
-             if face_img_gray.size == 0:
-                 continue
-             face_img_pil = Image.fromarray(face_img_gray)
-             emotion, probability = classificator.detect_image(face_img_pil)
-             emotion_class = map_emotion_to_class(emotion)
-
-             cv2.rectangle(
-                 img_disp,
-                 (x1, y1),
-                 (x2, y2),
-                 border_color,
-                 thickness=2
-             )
-             cv2.putText(
-                 img_disp, emotion, (x + 30, y - 30),
-                 cv2.FONT_HERSHEY_SIMPLEX, 1, font_color, 1
              )
-             # Show probability
-             cv2.putText(
-                 img_disp, str(round(probability, 3)), (x + 30, y - 50),
-                 cv2.FONT_HERSHEY_SIMPLEX, 0.3, font_color, 1
              )
-             current_class = emotion_class
-
-         # Track emotion class changes
-         if current_class:
-             history = st.session_state['emotion_history']
-             history.append(current_class)
-             if len(history) > 10:
-                 history.pop(0)
-             # Detect rapid changes
-             if len(history) >= 3 and len(set(history[-3:])) > 1:
-                 self.rapid_change_count += 1
-             else:
-                 self.rapid_change_count = 0
-
-         return frame.from_ndarray(img_disp, format="bgr24")
-
- webrtc_streamer(
-     key="emotion-detection",
-     video_processor_factory=EmotionRecognitionProcessor,
- )
-
- # --- Streamlit alert for rapid emotion changes ---
- history = st.session_state['emotion_history']
- if len(history) >= 3 and len(set(history[-3:])) > 1:
-     st.warning(
-         "⚠️ Rapid changes in your detected emotional state were observed. "
-         "Micro-expressions may not always reflect your true feelings. "
-         "If you feel emotionally unstable or distressed, " \
-         "consider reaching out to a mental health professional, "
-         "talking it over with a close person or taking a break."
      )
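
The removed get_connected_cameras shelled out to v4l2-ctl, which exists only on Linux; with streamlit_webrtc the browser negotiates the capture device, so server-side camera selection becomes redundant. If server-side enumeration were ever needed again, a portable sketch (a hypothetical helper, not part of this commit) could probe indices with OpenCV instead:

    import cv2

    def probe_cameras(max_index=5):
        # Hypothetical helper: try to open each index and keep the ones that respond.
        found = []
        for i in range(max_index):
            cap = cv2.VideoCapture(i)
            if cap.isOpened():
                found.append(i)
            cap.release()
        return found

The added side of the diff follows.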
  import os
+ import sys
+ import tempfile
+
+ sys.stderr = open(os.devnull, 'w')
+ os.environ["HOME"] = "/tmp"
+ os.environ["STREAMLIT_HOME"] = "/tmp"
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

  import cv2
  import numpy as np

  import streamlit as st
  from streamlit_webrtc import VideoProcessorBase, webrtc_streamer

+ import matplotlib.pyplot as plt
+ from huggingface_hub import hf_hub_download

+ import tensorflow as tf
+ gpus = tf.config.experimental.list_physical_devices('GPU')
+ if gpus:
      try:
+         for gpu in gpus:
+             tf.config.experimental.set_memory_growth(gpu, True)
+     except Exception as e:
+         print(e)
+
+ # --- Utility functions (from utils/utils.py) ---
+ # You must ensure these are implemented or import them if available.
+ from nets import get_model_from_name
+ from utils.utils import (cvtColor, get_classes, letterbox_image, preprocess_input)
+
+
+ # --- Classification class (merged from classification.py) ---
+ cache_dir = os.path.join(tempfile.gettempdir(), "hf_cache")
+ os.makedirs(cache_dir, exist_ok=True)
+
+ class Classification(object):
+     _defaults = {
+         "model_path": hf_hub_download(
+             repo_id="sudo-paras-shah/micro-expression-casme2",
+             filename="ep089.weights.h5",
+             cache_dir=cache_dir
+         ),
+         "classes_path": 'src/model_data/cls_classes.txt',
+         "input_shape": [224, 224],
+         "backbone": 'vgg16',
+         "alpha": 0.25
+     }
+
+     @classmethod
+     def get_defaults(cls, n):
+         if n in cls._defaults:
+             return cls._defaults[n]
+         else:
+             return "Unrecognized attribute name '" + n + "'"
+
+     def __init__(self, **kwargs):
+         self.__dict__.update(self._defaults)
+         for name, value in kwargs.items():
+             setattr(self, name, value)
+         self.class_names, self.num_classes = get_classes(self.classes_path)
+         self.generate()
+
+     def generate(self):
+         model_path = os.path.expanduser(self.model_path)
+         assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
+         if self.backbone == "mobilenet":
+             self.model = get_model_from_name[self.backbone](
+                 input_shape=[self.input_shape[0], self.input_shape[1], 3],
+                 classes=self.num_classes,
+                 alpha=self.alpha
              )
+         else:
+             self.model = get_model_from_name[self.backbone](
+                 input_shape=[self.input_shape[0], self.input_shape[1], 3],
+                 classes=self.num_classes
              )
+         self.model.load_weights(self.model_path)
+         print('{} model, and classes {} loaded.'.format(model_path, self.class_names))
+
+     def detect_image(self, image):
+         image = cvtColor(image)
+         image_data = letterbox_image(image, [self.input_shape[1], self.input_shape[0]])
+         image_data = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)
+         preds = self.model.predict(image_data)[0]
+         class_name = self.class_names[np.argmax(preds)]
+         probability = np.max(preds)
+         return class_name, probability
+
+ # --- Main Streamlit App ---
+ if __name__ == '__main__':
+     @st.cache_resource
+     def get_model():
+         return Classification()
+
+     classificator = get_model()
+     face_cascade = cv2.CascadeClassifier(
+         cv2.data.haarcascades + 'haarcascade_frontalface_alt.xml'
+     )
+
+     if face_cascade.empty():
+         st.error("Failed to load Haarcascade XML. Check the path.")
+
+     st.title("Real-Time Micro-Emotion Recognition")
+     st.write("Turn on your camera and detect emotions in real-time.")
+
+     def face_detect(img):
+         try:
+             img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+             faces = face_cascade.detectMultiScale(
+                 img_gray,
+                 scaleFactor=1.1,
+                 minNeighbors=1,
+                 minSize=(30, 30)
              )
+             return img, img_gray, faces
+         except Exception as e:
+             st.error(f"OpenCV face detection error: {e}")
+             return img, np.zeros_like(img), []
+
+     def map_emotion_to_class(emotion):
+         positive = ['happiness', 'happy']
+         negative = ['disgust', 'sadness', 'fear', 'sad', 'angry', 'disgusted']
+         surprise = ['surprise']
+         others = ['repression', 'tense', 'neutral', 'others']
+         e = emotion.lower()
+         if any(p in e for p in positive):
+             return 'Positive'
+         elif any(n in e for n in negative):
+             return 'Negative'
+         elif any(s in e for s in surprise):
+             return 'Surprise'
+         else:
+             return 'Others'
+
+     if 'emotion_history' not in st.session_state:
+         st.session_state['emotion_history'] = []
+
+     class EmotionRecognitionProcessor(VideoProcessorBase):
+         def __init__(self):
+             self.last_class = None
+             self.rapid_change_count = 0
+             self.frame_count = 0
+             self.last_faces = []
+             self.last_img_gray = None
+             self.last_results = []
+
+         def recv(self, frame):
+             border_color = (255, 0, 0)
+             font_color = (0, 0, 255)
+             try:
+                 img = frame.to_ndarray(format="bgr24")
+                 self.frame_count += 1
+
+                 # Only run detection every other frame; reuse previous results otherwise
+                 if self.frame_count % 2 == 0:
+                     img_disp, img_gray, faces = face_detect(img)
+                     self.last_faces = faces
+                     self.last_img_gray = img_gray
+                     self.last_results = []
+                     current_class = None
+
+                     if len(faces) == 0:
+                         cv2.putText(
+                             img_disp, 'No Face Detect.', (2, 20),
+                             cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1
+                         )
+
+                     for (x, y, w, h) in faces:
+                         x1, y1 = max(x - 10, 0), max(y - 10, 0)
+                         x2 = min(x + w + 10, img_disp.shape[1])
+                         y2 = min(y + h + 10, img_disp.shape[0])
+
+                         face_img_gray = img_gray[y1:y2, x1:x2]
+                         if face_img_gray.size == 0:
+                             continue
+                         face_img_pil = Image.fromarray(face_img_gray)
+                         emotion, probability = classificator.detect_image(face_img_pil)
+                         emotion_class = map_emotion_to_class(emotion)
+
+                         self.last_results.append((x1, y1, x2, y2, emotion, probability, emotion_class))
+                         current_class = emotion_class
+
+                     if current_class:
+                         history = st.session_state['emotion_history']
+                         history.append(current_class)
+                         if len(history) > 10:
+                             history.pop(0)
+                         if len(history) >= 3 and len(set(history[-3:])) > 1:
+                             self.rapid_change_count += 1
+                         else:
+                             self.rapid_change_count = 0
+
+                 else:
+                     img_disp = img.copy()
+                     img_gray = self.last_img_gray
+                     faces = self.last_faces
+                     for (x1, y1, x2, y2, emotion, probability, emotion_class) in self.last_results:
+                         cv2.rectangle(
+                             img_disp,
+                             (x1, y1),
+                             (x2, y2),
+                             border_color,
+                             thickness=2
+                         )
+                         cv2.putText(
+                             img_disp, emotion, (x1 + 30, y1 - 30),
+                             cv2.FONT_HERSHEY_SIMPLEX, 1, font_color, 1
+                         )
+                         cv2.putText(
+                             img_disp, str(round(probability, 3)), (x1 + 30, y1 - 50),
+                             cv2.FONT_HERSHEY_SIMPLEX, 0.3, font_color, 1
+                         )
+
+                     if len(faces) == 0:
+                         cv2.putText(
+                             img_disp, 'No Face Detect.', (2, 20),
+                             cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1
+                         )
+
+                 return frame.from_ndarray(img_disp, format="bgr24")
+             except Exception as e:
+                 st.error(f"Error in video processing: {e}")
+                 return frame
+
+     webrtc_streamer(
+         key="emotion-detection",
+         video_processor_factory=EmotionRecognitionProcessor,
+         media_stream_constraints={"video": True, "audio": False},
      )
+
+     history = st.session_state['emotion_history']
+     if len(history) >= 3 and len(set(history[-3:])) > 1:
+         st.warning(
+             "⚠️ Rapid changes in your detected emotional state were observed. "
+             "Micro-expressions may not always reflect your true feelings. "
+             "If you feel emotionally unstable or distressed, "
+             "consider reaching out to a mental health professional, "
+             "talking it over with a close person or taking a break."
+         )
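
The merged file still imports cvtColor, get_classes, letterbox_image, and preprocess_input from utils.utils, and its own comment warns that these must be implemented or importable. For orientation, a minimal sketch of what such helpers conventionally do in this style of classification repo (assumptions about the real utils/utils.py, not its actual contents):

    import numpy as np
    from PIL import Image

    def cvtColor(image):
        # Force 3-channel RGB so grayscale input does not break prediction.
        return image.convert('RGB') if image.mode != 'RGB' else image

    def get_classes(classes_path):
        # One class name per line; returns the names and their count.
        with open(classes_path, encoding='utf-8') as f:
            names = [line.strip() for line in f if line.strip()]
        return names, len(names)

    def letterbox_image(image, size):
        # Resize without distortion, padding the borders gray.
        w, h = size
        iw, ih = image.size
        scale = min(w / iw, h / ih)
        nw, nh = int(iw * scale), int(ih * scale)
        canvas = Image.new('RGB', (w, h), (128, 128, 128))
        canvas.paste(image.resize((nw, nh), Image.BICUBIC), ((w - nw) // 2, (h - nh) // 2))
        return canvas

    def preprocess_input(x):
        # Simple [0, 1] scaling; the real repo may use backbone-specific preprocessing.
        return x / 255.0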