hetchyy Claude Opus 4.6 committed on
Commit
d640378
·
1 Parent(s): c836860

Fix session persistence: pickle VAD artifacts to preserve tensor types

Browse files

clean_speech_intervals() expects the original torch.Tensor types from
VAD output. The previous np.save/np.load round trip converted them to
numpy arrays, causing resegment_session to crash. Sessions now use pickle
for speech_intervals and is_complete, keeping np.save only for audio.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. src/api/session_api.py +43 -33
src/api/session_api.py CHANGED
@@ -8,6 +8,7 @@ re-uploads and re-inference.
8
  import hashlib
9
  import json
10
  import os
 
11
  import re
12
  import shutil
13
  import time
@@ -36,24 +37,8 @@ def _validate_id(audio_id: str) -> bool:
36
  return isinstance(audio_id, str) and bool(_VALID_ID.match(audio_id))
37
 
38
 
39
- def _is_expired(meta: dict) -> bool:
40
- return (time.time() - meta.get("created_at", 0)) > SESSION_EXPIRY_SECONDS
41
-
42
-
43
- def _read_metadata(session_path):
44
- meta_path = session_path / "metadata.json"
45
- if not meta_path.exists():
46
- return None
47
- with open(meta_path) as f:
48
- return json.load(f)
49
-
50
-
51
- def _write_metadata(session_path, meta: dict):
52
- """Atomic write via temp file + os.replace."""
53
- tmp = session_path / "metadata.tmp"
54
- with open(tmp, "w") as f:
55
- json.dump(meta, f)
56
- os.replace(tmp, session_path / "metadata.json")
57
 
58
 
59
  def _sweep_expired():
@@ -68,8 +53,8 @@ def _sweep_expired():
68
  for entry in SESSION_DIR.iterdir():
69
  if not entry.is_dir():
70
  continue
71
- meta = _read_metadata(entry)
72
- if meta is None or _is_expired(meta):
73
  shutil.rmtree(entry, ignore_errors=True)
74
 
75
 
@@ -78,23 +63,37 @@ def _intervals_hash(intervals) -> str:
78
 
79
 
80
  def create_session(audio, speech_intervals, is_complete, intervals, model_name):
81
- """Persist session data and return audio_id (32-char hex UUID)."""
 
 
 
 
 
82
  _sweep_expired()
83
  audio_id = uuid.uuid4().hex
84
  path = _session_dir(audio_id)
85
  path.mkdir(parents=True, exist_ok=True)
86
 
 
87
  np.save(path / "audio.npy", audio)
88
- np.save(path / "speech_intervals.npy", speech_intervals)
89
 
 
 
 
 
 
 
90
  meta = {
91
- "is_complete": bool(is_complete),
92
  "intervals": intervals,
93
  "model_name": model_name,
94
  "intervals_hash": _intervals_hash(intervals),
95
- "created_at": time.time(),
96
  }
97
- _write_metadata(path, meta)
 
 
 
 
 
98
  return audio_id
99
 
100
 
@@ -105,18 +104,24 @@ def load_session(audio_id):
105
  path = _session_dir(audio_id)
106
  if not path.exists():
107
  return None
108
- meta = _read_metadata(path)
109
- if meta is None or _is_expired(meta):
 
110
  shutil.rmtree(path, ignore_errors=True)
111
  return None
112
 
113
  audio = np.load(path / "audio.npy")
114
- speech_intervals = np.load(path / "speech_intervals.npy")
 
 
 
 
 
115
 
116
  return {
117
  "audio": audio,
118
- "speech_intervals": speech_intervals,
119
- "is_complete": meta["is_complete"],
120
  "intervals": meta["intervals"],
121
  "model_name": meta["model_name"],
122
  "intervals_hash": meta.get("intervals_hash", ""),
@@ -127,15 +132,20 @@ def load_session(audio_id):
127
  def update_session(audio_id, *, intervals=None, model_name=None):
128
  """Update mutable session fields (intervals, model_name)."""
129
  path = _session_dir(audio_id)
130
- meta = _read_metadata(path)
131
- if meta is None:
132
  return
 
 
133
  if intervals is not None:
134
  meta["intervals"] = intervals
135
  meta["intervals_hash"] = _intervals_hash(intervals)
136
  if model_name is not None:
137
  meta["model_name"] = model_name
138
- _write_metadata(path, meta)
 
 
 
139
 
140
 
141
  # ---------------------------------------------------------------------------
 
8
  import hashlib
9
  import json
10
  import os
11
+ import pickle
12
  import re
13
  import shutil
14
  import time
 
37
  return isinstance(audio_id, str) and bool(_VALID_ID.match(audio_id))
38
 
39
 
40
+ def _is_expired(created_at: float) -> bool:
41
+ return (time.time() - created_at) > SESSION_EXPIRY_SECONDS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
 
44
  def _sweep_expired():
 
53
  for entry in SESSION_DIR.iterdir():
54
  if not entry.is_dir():
55
  continue
56
+ ts_file = entry / "created_at"
57
+ if not ts_file.exists() or _is_expired(float(ts_file.read_text())):
58
  shutil.rmtree(entry, ignore_errors=True)
59
 
60
 
 
63
 
64
 
65
  def create_session(audio, speech_intervals, is_complete, intervals, model_name):
66
+ """Persist session data and return audio_id (32-char hex UUID).
67
+
68
+ Uses pickle for VAD artifacts (speech_intervals, is_complete) to
69
+ preserve exact types (torch.Tensor etc.) expected by the segmenter.
70
+ Uses np.save for the audio array (large, always float32 numpy).
71
+ """
72
  _sweep_expired()
73
  audio_id = uuid.uuid4().hex
74
  path = _session_dir(audio_id)
75
  path.mkdir(parents=True, exist_ok=True)
76
 
77
+ # Audio is always a float32 numpy array after preprocessing
78
  np.save(path / "audio.npy", audio)
 
79
 
80
+ # VAD artifacts: preserve original types via pickle
81
+ with open(path / "vad.pkl", "wb") as f:
82
+ pickle.dump({"speech_intervals": speech_intervals,
83
+ "is_complete": is_complete}, f)
84
+
85
+ # Lightweight metadata (JSON-safe types only)
86
  meta = {
 
87
  "intervals": intervals,
88
  "model_name": model_name,
89
  "intervals_hash": _intervals_hash(intervals),
 
90
  }
91
+ with open(path / "metadata.json", "w") as f:
92
+ json.dump(meta, f)
93
+
94
+ # Timestamp file for cheap expiry checks during sweep
95
+ (path / "created_at").write_text(str(time.time()))
96
+
97
  return audio_id
98
 
99
 
 
104
  path = _session_dir(audio_id)
105
  if not path.exists():
106
  return None
107
+
108
+ ts_file = path / "created_at"
109
+ if not ts_file.exists() or _is_expired(float(ts_file.read_text())):
110
  shutil.rmtree(path, ignore_errors=True)
111
  return None
112
 
113
  audio = np.load(path / "audio.npy")
114
+
115
+ with open(path / "vad.pkl", "rb") as f:
116
+ vad = pickle.load(f)
117
+
118
+ with open(path / "metadata.json") as f:
119
+ meta = json.load(f)
120
 
121
  return {
122
  "audio": audio,
123
+ "speech_intervals": vad["speech_intervals"],
124
+ "is_complete": vad["is_complete"],
125
  "intervals": meta["intervals"],
126
  "model_name": meta["model_name"],
127
  "intervals_hash": meta.get("intervals_hash", ""),
 
132
  def update_session(audio_id, *, intervals=None, model_name=None):
133
  """Update mutable session fields (intervals, model_name)."""
134
  path = _session_dir(audio_id)
135
+ meta_path = path / "metadata.json"
136
+ if not meta_path.exists():
137
  return
138
+ with open(meta_path) as f:
139
+ meta = json.load(f)
140
  if intervals is not None:
141
  meta["intervals"] = intervals
142
  meta["intervals_hash"] = _intervals_hash(intervals)
143
  if model_name is not None:
144
  meta["model_name"] = model_name
145
+ tmp = path / "metadata.tmp"
146
+ with open(tmp, "w") as f:
147
+ json.dump(meta, f)
148
+ os.replace(tmp, meta_path)
149
 
150
 
151
  # ---------------------------------------------------------------------------