Spaces:
Running
Running
| """Batch audio processing service. | |
| Phase 10: Upload a set of audio files → batch-generate SOAP notes. | |
| Tracks progress and supports inter-session linking for follow-up visits. | |
| """ | |
| import asyncio | |
| import logging | |
| import time | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from enum import Enum | |
| from typing import Any, Dict, List, Optional | |
| from app.config import settings | |
| logger = logging.getLogger(__name__) | |
class BatchStatus(str, Enum):
    """Lifecycle states shared by batch jobs and their items.

    Subclasses ``str`` so values serialize directly in JSON responses.
    """

    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
@dataclass
class BatchItem:
    """One audio file inside a batch job.

    BUG FIX: the ``@dataclass`` decorator was missing, so the keyword
    construction ``BatchItem(item_id=..., filename=...)`` used by
    ``BatchProcessor.create_job`` raised TypeError (plain classes get no
    generated ``__init__``).
    """

    item_id: str
    filename: str
    status: BatchStatus = BatchStatus.QUEUED
    result: Optional[Dict[str, Any]] = None  # SOAP result on success
    error: Optional[str] = None              # error message on failure
    started_at: Optional[float] = None       # epoch seconds
    completed_at: Optional[float] = None     # epoch seconds
@dataclass
class BatchJob:
    """A batch of audio files processed together.

    BUG FIX: the ``@dataclass`` decorator was missing — without it the
    ``field(default_factory=...)`` defaults are inert (the class attribute
    would literally be a ``Field`` object) and the keyword construction in
    ``BatchProcessor.create_job`` raised TypeError.
    """

    batch_id: str
    items: List[BatchItem] = field(default_factory=list)
    status: BatchStatus = BatchStatus.QUEUED
    created_at: float = field(default_factory=time.time)
    completed_at: Optional[float] = None
    linked_session_id: Optional[str] = None  # For inter-session linking

    def progress(self) -> Dict[str, Any]:
        """Return aggregate progress counters for this job.

        Note: annotation widened from ``Dict[str, int]`` — ``percent``
        is a float produced by ``round(..., 1)``.
        """
        total = len(self.items)
        completed = sum(1 for i in self.items if i.status == BatchStatus.COMPLETED)
        failed = sum(1 for i in self.items if i.status == BatchStatus.FAILED)
        return {
            "total": total,
            "completed": completed,
            "failed": failed,
            "remaining": total - completed - failed,
            # Guard against division by zero for an empty job.
            "percent": round(completed / total * 100, 1) if total > 0 else 0,
        }
class BatchProcessor:
    """Manages batch audio processing jobs.

    Jobs live in memory only. Concurrency across items (of all jobs) is
    bounded by a shared asyncio semaphore.
    """

    def __init__(self, max_concurrent: int = 2):
        self._jobs: Dict[str, BatchJob] = {}
        # Caps how many items are transcribed/summarized at once.
        self._semaphore = asyncio.Semaphore(max_concurrent)

    def create_job(
        self,
        filenames: List[str],
        linked_session_id: Optional[str] = None,
    ) -> BatchJob:
        """Create and register a new batch job for the given files."""
        batch_id = str(uuid.uuid4())
        items = [
            BatchItem(item_id=str(uuid.uuid4()), filename=fn)
            for fn in filenames
        ]
        job = BatchJob(
            batch_id=batch_id,
            items=items,
            linked_session_id=linked_session_id,
        )
        self._jobs[batch_id] = job
        return job

    async def process_job(
        self,
        batch_id: str,
        process_fn,
    ) -> BatchJob:
        """Process all items in a batch job.

        Args:
            batch_id: The batch job ID.
            process_fn: Async callable(filename) -> Dict with SOAP result.

        Returns:
            The job with per-item results/errors filled in.

        Raises:
            ValueError: If no job exists for ``batch_id``.
        """
        job = self._jobs.get(batch_id)
        if not job:
            raise ValueError(f"Batch job {batch_id} not found")
        job.status = BatchStatus.PROCESSING

        async def _process_item(item: BatchItem):
            async with self._semaphore:
                # BUG FIX: respect cancellation — an item cancelled via
                # cancel_job() while waiting on the semaphore must not be
                # flipped back to PROCESSING.
                if item.status == BatchStatus.CANCELLED:
                    return
                item.status = BatchStatus.PROCESSING
                item.started_at = time.time()
                try:
                    item.result = await process_fn(item.filename)
                    item.status = BatchStatus.COMPLETED
                except Exception as e:
                    # One bad file must not sink the whole batch.
                    item.status = BatchStatus.FAILED
                    item.error = str(e)
                    logger.error("Batch item %s failed: %s", item.filename, e)
                finally:
                    item.completed_at = time.time()

        tasks = [_process_item(item) for item in job.items]
        await asyncio.gather(*tasks, return_exceptions=True)
        # BUG FIX: do not overwrite a CANCELLED job with COMPLETED/FAILED.
        if job.status != BatchStatus.CANCELLED:
            all_done = all(
                i.status in (BatchStatus.COMPLETED, BatchStatus.FAILED)
                for i in job.items
            )
            if all_done:
                any_success = any(
                    i.status == BatchStatus.COMPLETED for i in job.items
                )
                job.status = (
                    BatchStatus.COMPLETED if any_success else BatchStatus.FAILED
                )
                job.completed_at = time.time()
        return job

    def get_job(self, batch_id: str) -> Optional[BatchJob]:
        """Return the job for ``batch_id``, or None if unknown."""
        return self._jobs.get(batch_id)

    def cancel_job(self, batch_id: str) -> bool:
        """Cancel a job. Items already in flight are not interrupted."""
        job = self._jobs.get(batch_id)
        if not job or job.status in (BatchStatus.COMPLETED, BatchStatus.CANCELLED):
            return False
        job.status = BatchStatus.CANCELLED
        for item in job.items:
            if item.status == BatchStatus.QUEUED:
                item.status = BatchStatus.CANCELLED
        return True

    def list_jobs(self, limit: int = 20) -> List[Dict[str, Any]]:
        """Return summaries of the most recent jobs, newest first."""
        jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
        return [
            {
                "batch_id": j.batch_id,
                "status": j.status,
                # BUG FIX: progress is a method — call it so the summary
                # contains the counters dict, not a bound-method object.
                "progress": j.progress(),
                "created_at": j.created_at,
                "linked_session_id": j.linked_session_id,
            }
            for j in jobs[:limit]
        ]
| # ===================================================== | |
| # Inter-Session Linking | |
| # ===================================================== | |
class SessionLinker:
    """Links follow-up visits to original sessions for longitudinal view."""

    def __init__(self):
        self._links: Dict[str, List[str]] = {}  # parent_id -> [child_ids]
        self._reverse: Dict[str, str] = {}      # child_id -> parent_id

    def link(self, parent_session_id: str, child_session_id: str) -> None:
        """Record ``child_session_id`` as a follow-up of ``parent_session_id``."""
        if parent_session_id not in self._links:
            self._links[parent_session_id] = []
        if child_session_id not in self._links[parent_session_id]:
            self._links[parent_session_id].append(child_session_id)
        # A child has exactly one parent; re-linking overwrites the old one.
        self._reverse[child_session_id] = parent_session_id

    def get_chain(self, session_id: str) -> List[str]:
        """Get the full session chain (root → ... → current).

        BUG FIX: both walks now carry a visited-set guard so a malformed
        link cycle (e.g. a session linked as its own ancestor) terminates
        instead of looping forever.
        """
        # Walk up to the root, guarding against cycles.
        seen = {session_id}
        root = session_id
        while root in self._reverse:
            parent = self._reverse[root]
            if parent in seen:
                break
            seen.add(parent)
            root = parent
        # Walk down from the root, always following the latest follow-up.
        chain = [root]
        visited = {root}
        current = root
        while current in self._links:
            children = self._links[current]
            if not children:
                break
            nxt = children[-1]  # Latest follow-up
            if nxt in visited:
                break
            visited.add(nxt)
            chain.append(nxt)
            current = nxt
        return chain

    def get_follow_ups(self, session_id: str) -> List[str]:
        """Return the follow-up sessions of ``session_id`` (copy, safe to mutate)."""
        # BUG FIX: return a copy so callers can't corrupt internal state.
        return list(self._links.get(session_id, []))

    def get_parent(self, session_id: str) -> Optional[str]:
        """Return the parent session of ``session_id``, or None for roots."""
        return self._reverse.get(session_id)
| # ===================================================== | |
| # Audio Quality Check | |
| # ===================================================== | |
def check_audio_quality(audio_data, sample_rate: int = 16000) -> Dict[str, Any]:
    """Check audio quality metrics before processing.

    Returns SNR estimate and quality warnings.
    """
    import numpy as np

    # Accept raw PCM16 bytes or any array-like of samples.
    if isinstance(audio_data, bytes):
        samples = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
    else:
        samples = np.asarray(audio_data, dtype=np.float32)

    if len(samples) == 0:
        return {"quality": "empty", "snr_db": 0, "warnings": ["No audio data"]}

    # Overall loudness: RMS expressed in dB.
    rms = np.sqrt(np.mean(samples ** 2))
    rms_db = 20 * np.log10(max(rms, 1e-10))

    # Crude SNR: loudest ~10% of 25ms frame energies vs the quietest ~10%
    # (treated as the noise floor).
    frame_size = int(sample_rate * 0.025)  # 25ms frames
    n_frames = max(1, len(samples) // frame_size)
    energies = sorted(
        np.mean(samples[k * frame_size : (k + 1) * frame_size] ** 2)
        for k in range(n_frames)
    )
    decile = max(1, n_frames // 10)
    noise_floor = np.mean(energies[:decile])
    signal_level = np.mean(energies[-decile:])
    snr_db = 10 * np.log10(max(signal_level / max(noise_floor, 1e-10), 1e-10))

    duration_s = len(samples) / sample_rate

    warnings = []
    if snr_db < 10:
        warnings.append(f"Low SNR ({snr_db:.1f} dB) — noisy environment detected")
    if rms_db < -40:
        warnings.append(f"Very quiet audio ({rms_db:.1f} dB RMS)")
    if duration_s < 1.0:
        warnings.append(f"Very short audio ({duration_s:.1f}s)")
    if duration_s > settings.max_audio_duration_seconds:
        warnings.append(
            f"Audio exceeds max duration ({duration_s:.0f}s > {settings.max_audio_duration_seconds}s)"
        )

    if not warnings:
        quality = "good"
    elif snr_db < 5:
        quality = "poor"
    else:
        quality = "fair"

    return {
        "quality": quality,
        "snr_db": round(snr_db, 1),
        "rms_db": round(rms_db, 1),
        "duration_seconds": round(duration_s, 2),
        "warnings": warnings,
    }
| # ===================================================== | |
| # Singletons | |
| # ===================================================== | |
# Lazily-created process-wide singletons; use the accessor functions
# rather than touching these directly.
_batch_processor: Optional[BatchProcessor] = None
_session_linker: Optional[SessionLinker] = None
def get_batch_processor() -> BatchProcessor:
    """Return the process-wide BatchProcessor, creating it on first use."""
    global _batch_processor
    if _batch_processor is None:
        limit = settings.queue_max_concurrent_inferences
        _batch_processor = BatchProcessor(max_concurrent=limit)
    return _batch_processor
def get_session_linker() -> SessionLinker:
    """Return the process-wide SessionLinker, creating it on first use."""
    global _session_linker
    if _session_linker is None:
        _session_linker = SessionLinker()
    return _session_linker