Spaces:
Running
Running
fx
Browse files
app.py
CHANGED
|
@@ -68,7 +68,7 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
|
|
| 68 |
logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
|
| 69 |
|
| 70 |
return hidden_states, logits_age, logits_gender
|
| 71 |
-
|
| 72 |
# AgeGenderModel.forward() is switched to accept computed frozen CNN7 features from ExpressioNmodel
|
| 73 |
|
| 74 |
def _forward(
|
|
@@ -178,7 +178,7 @@ age_gender_model.wav2vec2.forward = types.MethodType(_forward, age_gender_model)
|
|
| 178 |
expression_model.wav2vec2.forward = types.MethodType(_forward_and_cnn7, expression_model)
|
| 179 |
|
| 180 |
def process_func(x: np.ndarray, sampling_rate: int) -> typing.Tuple[str, dict, str]:
|
| 181 |
-
|
| 182 |
# batch audio
|
| 183 |
y = expression_processor(x, sampling_rate=sampling_rate)
|
| 184 |
y = y['input_values'][0]
|
|
@@ -227,7 +227,7 @@ def recognize(input_file: str) -> typing.Tuple[str, dict, str]:
|
|
| 227 |
return process_func(signal, target_rate)
|
| 228 |
|
| 229 |
|
| 230 |
-
def
|
| 231 |
r"""3D pixel plot of arousal, dominance, valence."""
|
| 232 |
# Voxels per dimension
|
| 233 |
voxels = 7
|
|
@@ -271,6 +271,105 @@ def plot_expression(arousal, dominance, valence):
|
|
| 271 |
verticalalignment="top",
|
| 272 |
)
|
| 273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
|
| 276 |
description = (
|
|
|
|
| 68 |
logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
|
| 69 |
|
| 70 |
return hidden_states, logits_age, logits_gender
|
| 71 |
+
|
| 72 |
# AgeGenderModel.forward() is switched to accept computed frozen CNN7 features from ExpressioNmodel
|
| 73 |
|
| 74 |
def _forward(
|
|
|
|
| 178 |
expression_model.wav2vec2.forward = types.MethodType(_forward_and_cnn7, expression_model)
|
| 179 |
|
| 180 |
def process_func(x: np.ndarray, sampling_rate: int) -> typing.Tuple[str, dict, str]:
|
| 181 |
+
|
| 182 |
# batch audio
|
| 183 |
y = expression_processor(x, sampling_rate=sampling_rate)
|
| 184 |
y = y['input_values'][0]
|
|
|
|
| 227 |
return process_func(signal, target_rate)
|
| 228 |
|
| 229 |
|
| 230 |
+
def plot_expression_RIGID(arousal, dominance, valence):
|
| 231 |
r"""3D pixel plot of arousal, dominance, valence."""
|
| 232 |
# Voxels per dimension
|
| 233 |
voxels = 7
|
|
|
|
| 271 |
verticalalignment="top",
|
| 272 |
)
|
| 273 |
|
| 274 |
+
# Shared colormap for expression plots.
# NOTE(review): COLORMAP is not referenced in the visible code (plot_expression
# builds its own 'cool' colormap locally) — presumably used elsewhere; confirm.
COLORMAP = plt.get_cmap('coolwarm')
# Voxels per dimension of the arousal/dominance/valence cube.
N_PIX = 5

# Use STIX fonts so math text matches the serif body font in figures.
matplotlib.rcParams['mathtext.fontset'] = 'stix'
matplotlib.rcParams['font.family'] = 'STIXGeneral'
|
| 279 |
+
|
| 280 |
+
def explode(data):
    """Upscale a cubic voxel array, inserting zero-filled gaps.

    An input of shape (s, s, s) becomes (2s - 1, 2s - 1, 2s - 1):
    original voxels land on the even indices, odd indices stay zero.
    This is the standard matplotlib voxel-demo trick for drawing
    visible seams between cells.
    """
    doubled = 2 * np.asarray(data.shape)
    exploded = np.zeros(doubled - 1, dtype=data.dtype)
    exploded[::2, ::2, ::2] = data
    return exploded
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def plot_expression(arousal, dominance, valence):
    """3D voxel plot locating (arousal, dominance, valence) in the unit cube.

    Args:
        arousal: float, expected in [0, 1].
        dominance: float, expected in [0, 1]; the axis is displayed reversed.
        valence: float, expected in [0, 1].

    Side effect: creates a new matplotlib figure with a 3D axes and draws
    into it (no value is returned here; the caller grabs the figure).
    """
    N_PIX = 5  # voxels per dimension (matches the module-level constant)

    # Faint random background (numpy array, despite older notes calling it a
    # "cuda tensor") so empty voxels are still faintly visible, then light up
    # the single voxel containing the (a, d, v) point.
    _h = np.random.rand(N_PIX, N_PIX, N_PIX) * 1e-3
    adv = np.array([arousal, .994 - dominance, valence]).clip(0, .99)
    arousal, dominance, valence = (adv * N_PIX).astype(np.int64)  # find voxel
    _h[arousal, dominance, valence] = .22

    filled = np.ones((N_PIX, N_PIX, N_PIX), dtype=bool)

    # Upscale the voxel image, leaving gaps between cells.
    filled_2 = explode(filled)

    # Shrink the gaps: shift the odd-index grid lines inward by one.
    x, y, z = np.indices(np.array(filled_2.shape) + 1).astype(float) // 2
    x[1::2, :, :] += 1
    y[:, 1::2, :] += 1
    z[:, :, 1::2] += 1

    ax = plt.figure().add_subplot(projection='3d')

    # RGBA facecolors: alpha channel carries the voxel intensities,
    # RGB comes from the 'cool' colormap applied to those intensities.
    f_2 = np.ones([2 * N_PIX - 1,
                   2 * N_PIX - 1,
                   2 * N_PIX - 1, 4], dtype=np.float64)
    f_2[:, :, :, 3] = explode(_h)
    cm = plt.get_cmap('cool')
    f_2[:, :, :, :3] = cm(f_2[:, :, :, 3])[..., :3]

    # Keep alpha away from 0 (invisible) and 1 (opaque wall of voxels).
    f_2[:, :, :, 3] = f_2[:, :, :, 3].clip(.01, .74)

    ecolors_2 = f_2

    ax.voxels(x, y, z, filled_2, facecolors=f_2, edgecolors=.006 * ecolors_2)
    ax.set_aspect('equal')
    ax.set_zticks([0, N_PIX])
    ax.set_xticks([0, N_PIX])
    ax.set_yticks([0, N_PIX])

    ax.set_zticklabels([f'{n/N_PIX:.2f}'[0:] for n in ax.get_zticks()])
    ax.set_zlabel('valence', fontsize=10, labelpad=0)
    ax.set_xticklabels([f'{n/N_PIX:.2f}' for n in ax.get_xticks()])
    ax.set_xlabel('arousal', fontsize=10, labelpad=7)
    # Dominance axis is reversed (1 - n/N_PIX) to match the .994 - dominance
    # flip above; rotation corrected from 275 to 90 degrees.
    ax.set_yticklabels([f'{1-n/N_PIX:.2f}' for n in ax.get_yticks()], rotation=90)
    ax.set_ylabel('dominance', fontsize=10, labelpad=10)
    ax.grid(False)

    # Emphasize two top-face edges next to the axis labels.
    ax.plot([N_PIX, N_PIX], [0, N_PIX + .2], [N_PIX, N_PIX], 'g', linewidth=1)
    ax.plot([0, N_PIX], [N_PIX, N_PIX + .24], [N_PIX, N_PIX], 'k', linewidth=1)

    # Remaining lines on the top face.
    ax.plot([0, 0], [0, N_PIX], [N_PIX, N_PIX], 'darkred', linewidth=1)
    ax.plot([0, N_PIX], [0, 0], [N_PIX, N_PIX], 'darkblue', linewidth=1)

    # Set pane colors after plotting the lines.
    # BUGFIX: ax.w_xaxis/w_yaxis/w_zaxis were deprecated in matplotlib 3.4 and
    # removed in 3.8 (AttributeError on current releases); ax.xaxis etc. is
    # the supported spelling and exposes the same set_pane_color.
    ax.xaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.yaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.zaxis.set_pane_color((0.8, 0.8, 0.8, 0.0))

    # Restore the limits to prevent the plot from expanding.
    ax.set_xlim(0, N_PIX)
    ax.set_ylim(0, N_PIX)
    ax.set_zlim(0, N_PIX)
|
| 372 |
+
# ------
|
| 373 |
|
| 374 |
|
| 375 |
description = (
|