Commit ·
29dd04b
1
Parent(s): a93cb2c
Change api.py
Browse files
api.py
CHANGED
|
@@ -64,7 +64,7 @@ def inference(model, device, comments: str | list):
|
|
| 64 |
if args["num_categories"] > 1:
|
| 65 |
batch_size, total_classes = outputs.shape
|
| 66 |
if total_classes % args["num_categories"] != 0:
|
| 67 |
-
raise ValueError(f"Error: Number of total classes in the batch must of divisible by {args[
|
| 68 |
|
| 69 |
classes_per_group = total_classes // args["num_categories"]
|
| 70 |
# Group every classes_per_group values along dim=1
|
|
|
|
| 64 |
if args["num_categories"] > 1:
|
| 65 |
batch_size, total_classes = outputs.shape
|
| 66 |
if total_classes % args["num_categories"] != 0:
|
| 67 |
+
raise ValueError(f"Error: Number of total classes in the batch must of divisible by {args['num_categories']}")
|
| 68 |
|
| 69 |
classes_per_group = total_classes // args["num_categories"]
|
| 70 |
# Group every classes_per_group values along dim=1
|