Commit ·
0f7ecda
1
Parent(s): a0a72c0
Added normalization using torch
Browse files- handler.py +4 -5
handler.py
CHANGED
|
@@ -118,14 +118,13 @@ class EndpointHandler:
|
|
| 118 |
self.logger.info("Squeezing tensor")
|
| 119 |
batch_emb = frame_embedding.squeeze(0)
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
self.logger.info("Converting into numpy array")
|
| 122 |
batch_emb = batch_emb.cpu().detach().numpy()
|
| 123 |
|
| 124 |
-
# NORMALIZE
|
| 125 |
-
# self.logger.info("Normalizing numpy array")
|
| 126 |
-
# batch_emb = batch_emb.T / np.linalg.norm(batch_emb, axis=1)
|
| 127 |
-
# transpose back to (21, 512)
|
| 128 |
-
|
| 129 |
self.logger.info("Converting to list")
|
| 130 |
batch_emb = batch_emb.tolist()
|
| 131 |
|
|
|
|
| 118 |
self.logger.info("Squeezing tensor")
|
| 119 |
batch_emb = frame_embedding.squeeze(0)
|
| 120 |
|
| 121 |
+
# Normalize the embeddings
|
| 122 |
+
self.logger.info("Normalizing embeddings")
|
| 123 |
+
batch_emb = torch.nn.functional.normalize(batch_emb, p=2, dim=1)
|
| 124 |
+
|
| 125 |
self.logger.info("Converting into numpy array")
|
| 126 |
batch_emb = batch_emb.cpu().detach().numpy()
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
self.logger.info("Converting to list")
|
| 129 |
batch_emb = batch_emb.tolist()
|
| 130 |
|