Commit
·
8b0a86e
1
Parent(s):
29dd04b
Change api.py
Browse files
api.py
CHANGED
|
@@ -64,7 +64,7 @@ def inference(model, device, comments: str | list):
|
|
| 64 |
if args["num_categories"] > 1:
|
| 65 |
batch_size, total_classes = outputs.shape
|
| 66 |
if total_classes % args["num_categories"] != 0:
|
| 67 |
-
raise ValueError(
|
| 68 |
|
| 69 |
classes_per_group = total_classes // args["num_categories"]
|
| 70 |
# Group every classes_per_group values along dim=1
|
|
|
|
| 64 |
if args["num_categories"] > 1:
|
| 65 |
batch_size, total_classes = outputs.shape
|
| 66 |
if total_classes % args["num_categories"] != 0:
|
| 67 |
+
raise ValueError("Error: Number of total classes in the batch must of divisible by the number of categories.")
|
| 68 |
|
| 69 |
classes_per_group = total_classes // args["num_categories"]
|
| 70 |
# Group every classes_per_group values along dim=1
|