pgits committed on
Commit
4020b5c
·
verified ·
1 Parent(s): 6ba8fa9

Fix v1.1.1: app.py - transformers>=4.53.0 + 24kHz audio support

Browse files
Files changed (1) hide show
  1. app.py +29 -16
app.py CHANGED
@@ -14,7 +14,7 @@ from fastapi.responses import HTMLResponse
14
  import uvicorn
15
 
16
  # Version tracking
17
- VERSION = "1.1.0"
18
  COMMIT_SHA = "TBD"
19
 
20
  # Configure logging
@@ -40,9 +40,13 @@ async def load_model():
40
  from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
41
  model_id = "kyutai/stt-1b-en_fr"
42
 
 
43
  processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
 
 
44
  model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id).to(device)
45
- logger.info(f"Model {model_id} loaded successfully")
 
46
 
47
  except Exception as model_error:
48
  logger.warning(f"Could not load actual model: {model_error}")
@@ -55,15 +59,20 @@ async def load_model():
55
  model = "mock"
56
  processor = "mock"
57
 
58
- def transcribe_audio(audio_data: np.ndarray, sample_rate: int = 16000) -> str:
59
- """Transcribe audio data"""
60
  try:
61
  if model == "mock":
62
  # Mock transcription for development
63
- return f"Mock transcription: {len(audio_data)} samples at {sample_rate}Hz"
 
 
 
 
 
 
64
 
65
- # Real transcription
66
- inputs = processor(audio_data, sampling_rate=sample_rate, return_tensors="pt")
67
  inputs = {k: v.to(device) for k, v in inputs.items()}
68
 
69
  with torch.no_grad():
@@ -79,7 +88,7 @@ def transcribe_audio(audio_data: np.ndarray, sample_rate: int = 16000) -> str:
79
  # FastAPI app
80
  app = FastAPI(
81
  title="STT GPU Service Python v4",
82
- description="Real-time WebSocket STT streaming with kyutai/stt-1b-en_fr",
83
  version=VERSION
84
  )
85
 
@@ -99,7 +108,8 @@ async def health_check():
99
  "message": "STT WebSocket Service - Real-time streaming ready",
100
  "space_name": "stt-gpu-service-python-v4",
101
  "model_loaded": model is not None,
102
- "device": str(device) if device else "unknown"
 
103
  }
104
 
105
  @app.get("/", response_class=HTMLResponse)
@@ -123,13 +133,14 @@ async def get_index():
123
  <body>
124
  <div class="container">
125
  <h1>🎙️ STT GPU Service Python v4</h1>
126
- <p>Real-time WebSocket speech transcription service</p>
127
 
128
  <div class="status">
129
  <h3>WebSocket Streaming Test</h3>
130
  <button onclick="startWebSocket()">Connect WebSocket</button>
131
  <button onclick="stopWebSocket()" disabled id="stopBtn">Disconnect</button>
132
  <p>Status: <span id="wsStatus">Disconnected</span></p>
 
133
  </div>
134
 
135
  <div id="output">
@@ -158,7 +169,7 @@ async def get_index():
158
  // Send test message
159
  ws.send(JSON.stringify({{
160
  type: 'audio_chunk',
161
- data: 'test_audio_data',
162
  timestamp: Date.now()
163
  }}));
164
  }};
@@ -203,7 +214,8 @@ async def websocket_endpoint(websocket: WebSocket):
203
  "status": "connected",
204
  "message": "STT WebSocket ready for audio chunks",
205
  "chunk_size_ms": 80,
206
- "expected_sample_rate": 16000
 
207
  })
208
 
209
  while True:
@@ -212,15 +224,15 @@ async def websocket_endpoint(websocket: WebSocket):
212
 
213
  if data.get("type") == "audio_chunk":
214
  try:
215
- # Process 80ms audio chunk
216
  # In real implementation, you would:
217
  # 1. Decode base64 audio data
218
- # 2. Convert to numpy array
219
  # 3. Process with STT model
220
  # 4. Return transcription
221
 
222
  # For now, mock processing
223
- transcription = f"Mock transcription for chunk at {data.get('timestamp', 'unknown')}"
224
 
225
  # Send transcription result
226
  await websocket.send_json({
@@ -262,7 +274,8 @@ async def api_transcribe(audio_file: Optional[str] = None):
262
  "transcription": f"REST API transcription result for: {audio_file[:50]}...",
263
  "timestamp": time.time(),
264
  "version": VERSION,
265
- "method": "REST"
 
266
  }
267
 
268
  return result
 
14
  import uvicorn
15
 
16
  # Version tracking
17
+ VERSION = "1.1.1"
18
  COMMIT_SHA = "TBD"
19
 
20
  # Configure logging
 
40
  from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
41
  model_id = "kyutai/stt-1b-en_fr"
42
 
43
+ logger.info(f"Loading processor from {model_id}...")
44
  processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
45
+
46
+ logger.info(f"Loading model from {model_id}...")
47
  model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id).to(device)
48
+
49
+ logger.info(f"Model {model_id} loaded successfully on {device}")
50
 
51
  except Exception as model_error:
52
  logger.warning(f"Could not load actual model: {model_error}")
 
59
  model = "mock"
60
  processor = "mock"
61
 
62
+ def transcribe_audio(audio_data: np.ndarray, sample_rate: int = 24000) -> str:
63
+ """Transcribe audio data - expects 24kHz audio for Kyutai STT"""
64
  try:
65
  if model == "mock":
66
  # Mock transcription for development
67
+ duration = len(audio_data) / sample_rate
68
+ return f"Mock transcription: {duration:.2f}s audio at {sample_rate}Hz ({len(audio_data)} samples)"
69
+
70
+ # Real transcription - Kyutai STT expects 24kHz
71
+ if sample_rate != 24000:
72
+ logger.info(f"Resampling from {sample_rate}Hz to 24000Hz")
73
+ audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=24000)
74
 
75
+ inputs = processor(audio_data, sampling_rate=24000, return_tensors="pt")
 
76
  inputs = {k: v.to(device) for k, v in inputs.items()}
77
 
78
  with torch.no_grad():
 
88
  # FastAPI app
89
  app = FastAPI(
90
  title="STT GPU Service Python v4",
91
+ description="Real-time WebSocket STT streaming with kyutai/stt-1b-en_fr (24kHz)",
92
  version=VERSION
93
  )
94
 
 
108
  "message": "STT WebSocket Service - Real-time streaming ready",
109
  "space_name": "stt-gpu-service-python-v4",
110
  "model_loaded": model is not None,
111
+ "device": str(device) if device else "unknown",
112
+ "expected_sample_rate": "24000Hz"
113
  }
114
 
115
  @app.get("/", response_class=HTMLResponse)
 
133
  <body>
134
  <div class="container">
135
  <h1>🎙️ STT GPU Service Python v4</h1>
136
+ <p>Real-time WebSocket speech transcription service (24kHz audio)</p>
137
 
138
  <div class="status">
139
  <h3>WebSocket Streaming Test</h3>
140
  <button onclick="startWebSocket()">Connect WebSocket</button>
141
  <button onclick="stopWebSocket()" disabled id="stopBtn">Disconnect</button>
142
  <p>Status: <span id="wsStatus">Disconnected</span></p>
143
+ <p><small>Expected: 24kHz audio chunks (80ms = ~1920 samples)</small></p>
144
  </div>
145
 
146
  <div id="output">
 
169
  // Send test message
170
  ws.send(JSON.stringify({{
171
  type: 'audio_chunk',
172
+ data: 'test_audio_data_24khz',
173
  timestamp: Date.now()
174
  }}));
175
  }};
 
214
  "status": "connected",
215
  "message": "STT WebSocket ready for audio chunks",
216
  "chunk_size_ms": 80,
217
+ "expected_sample_rate": 24000,
218
+ "expected_chunk_samples": 1920 # 80ms at 24kHz = 1920 samples
219
  })
220
 
221
  while True:
 
224
 
225
  if data.get("type") == "audio_chunk":
226
  try:
227
+ # Process 80ms audio chunk (1920 samples at 24kHz)
228
  # In real implementation, you would:
229
  # 1. Decode base64 audio data
230
+ # 2. Convert to numpy array (24kHz)
231
  # 3. Process with STT model
232
  # 4. Return transcription
233
 
234
  # For now, mock processing
235
+ transcription = f"Mock transcription for 24kHz chunk at {data.get('timestamp', 'unknown')}"
236
 
237
  # Send transcription result
238
  await websocket.send_json({
 
274
  "transcription": f"REST API transcription result for: {audio_file[:50]}...",
275
  "timestamp": time.time(),
276
  "version": VERSION,
277
+ "method": "REST",
278
+ "expected_sample_rate": "24kHz"
279
  }
280
 
281
  return result