pgits committed on
Commit
19eacc4
·
verified ·
1 Parent(s): 736c54a

Migrate working code: Deploy app.py v1.0.1 to correct Space

Browse files
Files changed (1) hide show
  1. app.py +32 -249
app.py CHANGED
@@ -1,260 +1,43 @@
1
- import asyncio
2
- import json
3
- import logging
4
- import os
5
- import tempfile
6
  import time
7
- from typing import Optional
8
 
9
- import librosa
10
- import numpy as np
11
- import torch
12
- import uvicorn
13
- from fastapi import FastAPI, File, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
14
- from fastapi.responses import JSONResponse
15
- from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
16
 
17
- # Configure logging
18
- logging.basicConfig(level=logging.INFO)
19
- logger = logging.getLogger(__name__)
20
-
21
- # Global variables for model
22
- processor = None
23
- model = None
24
- device = None
25
-
26
- app = FastAPI(
27
- title="STT GPU Service Python v4",
28
- description="Real-time Speech-to-Text service using Kyutai Moshi model",
29
- version="1.0.0"
30
- )
31
-
32
- class ConnectionManager:
33
- def __init__(self):
34
- self.active_connections: list[WebSocket] = []
35
- self.max_connections = 2
36
-
37
- async def connect(self, websocket: WebSocket) -> bool:
38
- if len(self.active_connections) >= self.max_connections:
39
- return False
40
- await websocket.accept()
41
- self.active_connections.append(websocket)
42
- logger.info(f"WebSocket connected. Active connections: {len(self.active_connections)}")
43
- return True
44
-
45
- def disconnect(self, websocket: WebSocket):
46
- if websocket in self.active_connections:
47
- self.active_connections.remove(websocket)
48
- logger.info(f"WebSocket disconnected. Active connections: {len(self.active_connections)}")
49
-
50
- manager = ConnectionManager()
51
-
52
- def load_model():
53
- """Load the STT model and processor"""
54
- global processor, model, device
55
-
56
- try:
57
- device = "cuda" if torch.cuda.is_available() else "cpu"
58
- logger.info(f"Loading model on device: {device}")
59
-
60
- model_id = "kyutai/stt-1b-en_fr"
61
- processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
62
- model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
63
- model_id,
64
- device_map=device,
65
- torch_dtype=torch.float16 if device == "cuda" else torch.float32
66
- )
67
-
68
- logger.info("Model loaded successfully")
69
- return True
70
- except Exception as e:
71
- logger.error(f"Error loading model: {e}")
72
- return False
73
-
74
- def transcribe_audio_chunk(audio_data: np.ndarray, sample_rate: int = 16000) -> dict:
75
- """Transcribe audio chunk and return result with timestamps"""
76
- try:
77
- if processor is None or model is None:
78
- raise Exception("Model not loaded")
79
-
80
- # Process audio
81
- inputs = processor(
82
- audio_data,
83
- sampling_rate=sample_rate,
84
- return_tensors="pt"
85
- ).to(device)
86
-
87
- # Generate transcription
88
- with torch.no_grad():
89
- generated_ids = model.generate(
90
- **inputs,
91
- max_new_tokens=256,
92
- return_timestamps=True
93
- )
94
-
95
- # Decode with timestamps
96
- result = processor.batch_decode(
97
- generated_ids,
98
- skip_special_tokens=True,
99
- return_timestamps=True
100
- )[0]
101
-
102
- return {
103
- "text": result.get("text", ""),
104
- "chunks": result.get("chunks", []),
105
- "timestamp": time.time()
106
- }
107
- except Exception as e:
108
- logger.error(f"Transcription error: {e}")
109
- return {"text": "", "chunks": [], "error": str(e), "timestamp": time.time()}
110
-
111
- @app.on_event("startup")
112
- async def startup_event():
113
- """Load model on startup"""
114
- logger.info("Starting STT GPU Service Python v4...")
115
- success = load_model()
116
- if not success:
117
- logger.error("Failed to load model on startup")
118
- raise Exception("Model loading failed")
119
-
120
- @app.get("/health")
121
- async def health_check():
122
- """Health check endpoint"""
123
- gpu_available = torch.cuda.is_available()
124
- model_loaded = processor is not None and model is not None
125
-
126
  return {
127
- "status": "healthy" if model_loaded else "unhealthy",
128
- "model_loaded": model_loaded,
129
- "gpu_available": gpu_available,
130
- "device": device,
131
- "active_connections": len(manager.active_connections),
132
- "max_connections": manager.max_connections,
133
- "timestamp": time.time()
134
  }
135
 
136
- @app.post("/transcribe")
137
- async def transcribe_file(audio_file: UploadFile = File(...)):
138
- """REST endpoint for file upload transcription"""
139
- try:
140
- if processor is None or model is None:
141
- raise HTTPException(status_code=503, detail="Model not loaded")
142
-
143
- # Save uploaded file temporarily
144
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
145
- content = await audio_file.read()
146
- tmp_file.write(content)
147
- tmp_file_path = tmp_file.name
148
-
149
- try:
150
- # Load audio file
151
- audio_data, sample_rate = librosa.load(tmp_file_path, sr=16000, mono=True)
152
-
153
- # Transcribe
154
- result = transcribe_audio_chunk(audio_data, sample_rate)
155
-
156
- return JSONResponse(content={
157
- "filename": audio_file.filename,
158
- "transcription": result["text"],
159
- "chunks": result["chunks"],
160
- "timestamp": result["timestamp"],
161
- "error": result.get("error")
162
- })
163
- finally:
164
- # Clean up temp file
165
- os.unlink(tmp_file_path)
166
-
167
- except Exception as e:
168
- logger.error(f"File transcription error: {e}")
169
- raise HTTPException(status_code=500, detail=str(e))
170
 
171
- @app.websocket("/ws/stream")
172
- async def websocket_endpoint(websocket: WebSocket):
173
- """WebSocket endpoint for streaming audio transcription"""
 
174
 
175
- # Check connection limit
176
- if not await manager.connect(websocket):
177
- await websocket.close(code=1013, reason="Maximum connections reached")
178
- return
179
 
180
- try:
181
- await websocket.send_text(json.dumps({
182
- "type": "connection_established",
183
- "message": "Connected to STT service",
184
- "expected_format": "16kHz mono PCM audio in 80ms chunks",
185
- "chunk_size_bytes": 2560 # 80ms * 16000 Hz * 2 bytes/sample
186
- }))
187
-
188
- while True:
189
- # Receive audio data
190
- try:
191
- # Set timeout for receiving data (80ms + buffer)
192
- data = await asyncio.wait_for(websocket.receive_bytes(), timeout=0.2)
193
- except asyncio.TimeoutError:
194
- # Send keepalive
195
- await websocket.send_text(json.dumps({
196
- "type": "keepalive",
197
- "timestamp": time.time()
198
- }))
199
- continue
200
-
201
- # Convert bytes to numpy array (assuming 16-bit PCM)
202
- audio_chunk = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
203
-
204
- # Transcribe chunk
205
- result = transcribe_audio_chunk(audio_chunk, 16000)
206
-
207
- # Send result back
208
- response = {
209
- "type": "transcription",
210
- "text": result["text"],
211
- "chunks": result["chunks"],
212
- "timestamp": result["timestamp"],
213
- "chunk_duration_ms": 80,
214
- "error": result.get("error")
215
- }
216
-
217
- await websocket.send_text(json.dumps(response))
218
-
219
- except WebSocketDisconnect:
220
- logger.info("WebSocket disconnected normally")
221
- except Exception as e:
222
- logger.error(f"WebSocket error: {e}")
223
- try:
224
- await websocket.send_text(json.dumps({
225
- "type": "error",
226
- "message": str(e),
227
- "timestamp": time.time()
228
- }))
229
- except:
230
- pass
231
- finally:
232
- manager.disconnect(websocket)
233
-
234
- @app.get("/")
235
- async def root():
236
- """Root endpoint with service information"""
237
- return {
238
- "service": "STT GPU Service Python v4",
239
- "model": "kyutai/stt-1b-en_fr",
240
- "endpoints": {
241
- "health": "/health",
242
- "transcribe": "/transcribe (POST with audio file)",
243
- "stream": "/ws/stream (WebSocket for real-time)"
244
- },
245
- "streaming_info": {
246
- "chunk_size": "80ms",
247
- "sample_rate": "16kHz",
248
- "format": "mono PCM",
249
- "latency": "~200ms"
250
- }
251
- }
252
 
253
  if __name__ == "__main__":
254
- uvicorn.run(
255
- "app:app",
256
- host="0.0.0.0",
257
- port=7860,
258
- log_level="info",
259
- access_log=True
260
- )
 
1
+ import gradio as gr
 
 
 
 
2
  import time
 
3
 
4
+ # Semantic versioning - updated for correct Space
5
+ VERSION = "1.0.1"
6
+ COMMIT_SHA = "TBD" # Will be updated after push
 
 
 
 
7
 
8
+ def health_check():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  return {
10
+ "status": "healthy",
11
+ "timestamp": time.time(),
12
+ "version": VERSION,
13
+ "commit_sha": COMMIT_SHA,
14
+ "message": "STT Service - Ready for model integration",
15
+ "space_name": "stt-gpu-service-python-v4"
 
16
  }
17
 
18
+ def placeholder_transcribe(audio):
19
+ if audio is None:
20
+ return "No audio provided"
21
+ return f"Placeholder: Audio received (type: {type(audio)}) - STT model integration pending"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # Create interface with version display
24
+ with gr.Blocks(title="STT GPU Service Python v4") as demo:
25
+ gr.Markdown("# 🎙️ STT GPU Service Python v4")
26
+ gr.Markdown("Working deployment! Ready for STT model integration.")
27
 
28
+ with gr.Tab("Health Check"):
29
+ health_btn = gr.Button("Check Health")
30
+ health_output = gr.JSON()
31
+ health_btn.click(health_check, outputs=health_output)
32
 
33
+ with gr.Tab("Audio Test"):
34
+ audio_input = gr.Audio(type="numpy")
35
+ transcribe_btn = gr.Button("Test Transcribe")
36
+ output_text = gr.Textbox()
37
+ transcribe_btn.click(placeholder_transcribe, inputs=audio_input, outputs=output_text)
38
+
39
+ # Version display in small text at bottom as requested
40
+ gr.Markdown(f"<small>v{VERSION} (SHA: {COMMIT_SHA})</small>", elem_id="version-info")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  if __name__ == "__main__":
43
+ demo.launch()