Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -179,9 +179,16 @@ discriminator.eval()
|
|
| 179 |
|
| 180 |
def get_prediction(embeddings):
|
| 181 |
with torch.no_grad():
|
|
|
|
| 182 |
last_rep, logits, probs = discriminator(embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
-
|
| 185 |
predicted_labels = predicted_labels.cpu().numpy()
|
| 186 |
return predicted_labels
|
| 187 |
|
|
|
|
| 179 |
|
| 180 |
def get_prediction(embeddings):
|
| 181 |
with torch.no_grad():
|
| 182 |
+
# Forward pass through the discriminator to get the logits and probabilities
|
| 183 |
last_rep, logits, probs = discriminator(embeddings)
|
| 184 |
+
|
| 185 |
+
# Filter logits to ignore the last dimension (assuming you only care about the first two)
|
| 186 |
+
filtered_logits = logits[:, 0:-1]
|
| 187 |
+
|
| 188 |
+
# Get the predicted labels using the filtered logits
|
| 189 |
+
_, predicted_labels = torch.max(filtered_logits, dim=-1)
|
| 190 |
|
| 191 |
+
# Convert to numpy array if needed
|
| 192 |
predicted_labels = predicted_labels.cpu().numpy()
|
| 193 |
return predicted_labels
|
| 194 |
|