Ellie5757575757 commited on
Commit
b8b604c
·
verified ·
1 Parent(s): cda796a

Update output.py

Browse files
Files changed (1) hide show
  1. output.py +18 -1
output.py CHANGED
@@ -7,13 +7,14 @@ Aphasia classification inference (cleaned).
7
  - Adds predict_from_chajson(json_path, ...) helper
8
  """
9
 
10
- import json
11
  import os
12
  import math
13
  from dataclasses import dataclass
14
  from typing import Dict, List, Optional, Tuple
15
  from collections import defaultdict
16
 
 
17
  import numpy as np
18
  import torch
19
  import torch.nn as nn
@@ -594,6 +595,22 @@ def predict_from_chajson(model_dir: str, chajson_path: str, output_file: Optiona
594
  pass
595
  return out
596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597
 
598
  # ---------- CLI ----------
599
 
 
7
  - Adds predict_from_chajson(json_path, ...) helper
8
  """
9
 
10
+ import json as _json
11
  import os
12
  import math
13
  from dataclasses import dataclass
14
  from typing import Dict, List, Optional, Tuple
15
  from collections import defaultdict
16
 
17
+
18
  import numpy as np
19
  import torch
20
  import torch.nn as nn
 
595
  pass
596
  return out
597
 
598
+ def format_result(pred: dict, style: str = "json") -> str:
599
+ """Back-compat formatter. 'pred' is the dict returned by predict_*."""
600
+ if style == "json":
601
+ return _json.dumps(pred, ensure_ascii=False, indent=2)
602
+ # simple text summary
603
+ if isinstance(pred, dict) and "summary" in pred:
604
+ s = pred["summary"]
605
+ lines = [
606
+ f"Total sentences: {pred.get('total_sentences', 0)}",
607
+ f"Avg confidence: {s.get('average_confidence', 'N/A')}",
608
+ f"Avg fluency: {s.get('average_fluency_score', 'N/A')}",
609
+ f"Most common: {s.get('most_common_prediction', 'N/A')}",
610
+ ]
611
+ return "\n".join(lines)
612
+ return str(pred)
613
+
614
 
615
  # ---------- CLI ----------
616