Spaces:
Sleeping
Sleeping
| import re | |
| import os | |
| import numpy as np | |
| import torch | |
| import librosa | |
| import librosa.display | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| from torchvision import transforms | |
| import whisper | |
# Force the non-interactive Agg backend for server environments (no display);
# the spectrogram helpers below render figures straight into in-memory PNG
# buffers and never open a window.
# NOTE(review): this call comes after `import matplotlib.pyplot` above. Recent
# Matplotlib versions switch backends here anyway, but calling use() before
# importing pyplot is the conventional order — confirm for the pinned version.
matplotlib.use('Agg')
| # ========================================== | |
| # 0. Segmentation CSV Parser | |
| # ========================================== | |
def parse_segmentation_csv(csv_content: bytes) -> list:
    """
    Parse a segmentation CSV and return the intervals spoken by the PAR speaker.

    Expected columns: ``speaker,start_ms,end_ms``. A first row containing the
    word "speaker" is treated as a header and skipped. Rows for any other
    speaker, and rows with fewer than three fields, are ignored.

    Args:
        csv_content: Raw CSV bytes. Decoded as UTF-8; a leading byte-order
            mark is dropped and undecodable bytes are replaced.

    Returns:
        List of ``(start_ms, end_ms)`` integer tuples for PAR rows, in file
        order. This function never raises: a row whose times are not integers
        is reported and skipped on its own, and if the content cannot be read
        at all the error is reported and an empty list is returned.
    """
    import csv

    intervals = []
    try:
        # 'utf-8-sig' strips a BOM, which would otherwise glue itself onto the
        # first field ('\ufeffPAR') and make the first row silently not match.
        text = csv_content.decode('utf-8-sig', errors='replace')
        # csv.reader (rather than str.split) also handles quoted fields.
        rows = list(csv.reader(text.strip().splitlines()))
    except Exception as e:  # best-effort contract: report, never raise
        print(f"Error parsing segmentation CSV: {e}")
        return intervals

    for i, row in enumerate(rows):
        if i == 0 and 'speaker' in ','.join(row).lower():
            continue  # Skip header
        if len(row) < 3 or row[0].strip().upper() != 'PAR':
            continue
        try:
            start_ms = int(row[1].strip())
            end_ms = int(row[2].strip())
        except ValueError as e:
            # One bad row must not discard the rest of the file.
            print(f"Error parsing segmentation CSV row {i + 1}: {e}")
            continue
        intervals.append((start_ms, end_ms))
    return intervals
| # ========================================== | |
| # 1. Linguistic Feature Extractor | |
| # ========================================== | |
class LinguisticFeatureExtractor:
    """
    Counts CHAT transcription markers (fillers, repetitions, retracings,
    incomplete utterances, error codes, pauses) in a transcript, builds a
    word-count-normalized feature vector from them, and picks out the
    transcript segments that contain the most markers.
    """

    # Segment delimiters for extract_key_segments: runs of '.', '?' or '!'
    # (unless the run is preceded by '(' or '[' — or is the tail of an earlier
    # run), newlines, and NAK (0x15)-delimited start_end timestamp bullets.
    # The lookbehind keeps pause markers such as '(.)' / '(..)' and codes
    # such as '[?]' intact; splitting them apart would make the 'pauses'
    # pattern impossible to match inside a segment.
    _SEGMENT_SPLIT = re.compile(r'(?<![(\[.?!])[.?!]+|\n+|\x15\d+_\d+\x15')

    # '[PAUSE]' tokens NOT immediately followed by a '(.)'-style marker.
    # apply_chat_rules in this file emits '[PAUSE] (.)' pairs; counting both
    # halves would score a single pause twice.
    _BARE_PAUSE = re.compile(r'\[PAUSE\](?!\s*\(\.+\))')

    # Keys of self.patterns summed per segment by extract_key_segments.
    _SEGMENT_MARKER_KEYS = ('fillers', 'repetition', 'retracing', 'pauses', 'errors')

    def __init__(self):
        self.patterns = {
            # '&-uh', '&-um', ...
            'fillers': re.compile(r'&-([a-z]+)', re.IGNORECASE),
            # NOTE(review): this matches '[/]' but also '[//]' and '[///]', so
            # a retracing is counted here AND under 'retracing'. Left as-is
            # because get_feature_vector output may be what a downstream model
            # was trained on — confirm before tightening it to exactly '[/]'.
            'repetition': re.compile(r'\[/+\]'),
            'retracing': re.compile(r'\[//\]'),
            'incomplete': re.compile(r'\+[\./]+'),
            'errors': re.compile(r'\[\*.*?\]'),
            'pauses': re.compile(r'\(\.+\)')
        }

    def clean_for_bert(self, raw_text):
        """
        Strip CHAT annotation from an utterance, keeping pauses as tokens.

        Removes a leading '*PAR:' tier label, NAK (0x15)-delimited timestamp
        bullets, '<' / '>' characters and every '[...]' code; replaces CHAT
        pauses such as '(.)' with '[PAUSE]'; turns '_' into spaces; collapses
        whitespace; and drops one trailing '[PAUSE]'.
        """
        text = re.sub(r'^\*PAR:\s+', '', raw_text)
        text = re.sub(r'\x15\d+_\d+\x15', '', text)
        text = re.sub(r'<|>', '', text)
        # Bracket codes go first, so pre-existing '[PAUSE]' tokens are removed
        # and only '(.)'-style markers produce the ones emitted below.
        text = re.sub(r'\[.*?\]', '', text)
        text = re.sub(r'\(\.+\)', '[PAUSE]', text)
        text = text.replace('_', ' ')
        text = re.sub(r'\s+', ' ', text).strip()
        if text.endswith('[PAUSE]'):
            text = text[:-7].strip()  # 7 == len('[PAUSE]')
        return text

    @staticmethod
    def _stat_words(raw_text):
        """Lower-cased word tokens with '[...]' codes, '&-' fillers and punctuation removed."""
        text = re.sub(r'\[.*?\]', '', raw_text)
        # NOTE(review): unlike the 'fillers' pattern this is case-sensitive,
        # so an upper-case '&-UH' is counted as a filler AND leaves 'uh' behind
        # as a word. Kept unchanged for feature stability.
        text = re.sub(r'&-([a-z]+)', '', text)
        text = re.sub(r'[^\w\s]', '', text)
        return text.lower().split()

    def _count_markers_and_words(self, raw_text):
        """Return (stats dict as described on get_features, list of word tokens)."""
        p = self.patterns
        stats = {
            'filler_count': len(p['fillers'].findall(raw_text)),
            'repetition_count': len(p['repetition'].findall(raw_text)),
            'retracing_count': len(p['retracing'].findall(raw_text)),
            'incomplete_count': len(p['incomplete'].findall(raw_text)),
            'error_count': len(p['errors'].findall(raw_text)),
            'pause_count': len(p['pauses'].findall(raw_text))
        }
        words = self._stat_words(raw_text)
        stats['word_count'] = len(words)
        return stats, words

    def get_features(self, raw_text):
        """
        Count CHAT markers and words in raw_text.

        Returns a dict with 'filler_count', 'repetition_count',
        'retracing_count', 'incomplete_count', 'error_count' and
        'pause_count' (raw pattern-match counts), plus 'word_count' (tokens
        left after removing bracket codes, '&-' fillers and punctuation).
        """
        return self._count_markers_and_words(raw_text)[0]

    def get_feature_vector(self, raw_text):
        """
        Return a float32 vector of 6 features:
        [type-token ratio, fillers, repetitions, retracings, errors, pauses].

        Every entry is divided by the word count, which is treated as 1 when
        there are no words. 'incomplete_count' is not part of the vector.
        """
        stats, words = self._count_markers_and_words(raw_text)
        n = stats['word_count'] or 1  # avoid division by zero on empty text
        return np.array([
            len(set(words)) / n,  # type-token ratio
            stats['filler_count'] / n,
            stats['repetition_count'] / n,
            stats['retracing_count'] / n,
            stats['error_count'] / n,
            stats['pause_count'] / n
        ], dtype=np.float32)

    def extract_key_segments(self, text, max_segments=3):
        """
        Return the segments of text with the highest linguistic-marker count.

        text is split at sentence terminators, newlines and timestamp bullets
        (pause markers like '(.)' are not split). If that leaves nothing but
        text is non-blank, its words are grouped into chunks of 15 instead.
        Segments of 10 characters or fewer are ignored.

        A segment's count is the number of filler, repetition, retracing,
        pause and error-code matches, plus any '[PAUSE]' token that is not
        directly followed by a '(.)'-style marker (so a paired pause counts once).

        Returns:
            Up to max_segments dicts {"text": str, "marker_count": int},
            highest count first; ties keep their original order.
        """
        segments = [s.strip() for s in self._SEGMENT_SPLIT.split(text) if s.strip()]

        if not segments and text.strip():
            # Only reached when text is made up solely of delimiters and whitespace.
            words = text.split()
            chunk_size = 15
            if len(words) > chunk_size:
                segments = [' '.join(words[i:i + chunk_size])
                            for i in range(0, len(words), chunk_size)]
            else:
                segments = [text.strip()]

        scored = []
        for sent in segments:
            if len(sent) <= 10:  # skip very short fragments
                continue
            count = sum(len(self.patterns[key].findall(sent))
                        for key in self._SEGMENT_MARKER_KEYS)
            count += len(self._BARE_PAUSE.findall(sent))
            scored.append({"text": sent, "marker_count": count})

        # list.sort is stable, so equal counts keep document order.
        scored.sort(key=lambda item: item['marker_count'], reverse=True)
        return scored[:max_segments]
| # ========================================== | |
| # 2. Audio Processor | |
| # ========================================== | |
class AudioProcessor:
    """
    Renders an audio file (optionally restricted to given millisecond
    intervals) as a mel-spectrogram image: either a normalized image tensor
    for a vision model or a base64 PNG data URI for display.
    """

    # Maximum amount of audio, in seconds, that is turned into a spectrogram.
    MAX_SECONDS = 30

    def __init__(self):
        # Resize the rendered image to 224x224, convert to a tensor and
        # normalize per channel (mean/std values match the usual ImageNet ones).
        self.vit_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _slice_intervals(y, sr, intervals, max_seconds=MAX_SECONDS):
        """
        Concatenate the [start_ms, end_ms) spans of signal y and cap the length.

        Spans are clamped to [0, len(y)]; spans that are empty after clamping
        are dropped. If no span survives, max_seconds of silence (zeros) is
        returned instead. The result is truncated to max_seconds * sr samples.
        """
        n = len(y)
        clips = []
        for start_ms, end_ms in intervals:
            # Clamp the start too: a negative index would wrap around to the end.
            start = max(0, int(start_ms * sr / 1000))
            end = min(n, int(end_ms * sr / 1000))
            # Dropping empty spans makes an all-empty selection fall back to
            # silence like the no-clip case, instead of a zero-length signal.
            if start < end:
                clips.append(y[start:end])
        if clips:
            y = np.concatenate(clips)
        else:
            y = np.zeros(int(sr * max_seconds))
        limit = int(max_seconds * sr)
        return y[:limit] if len(y) > limit else y

    def _load_audio(self, audio_path, intervals):
        """Return (samples, sample_rate) for audio_path, at most MAX_SECONDS long."""
        if intervals:
            # Whole file at its native sample rate, then keep only the spans.
            y, sr = librosa.load(audio_path, sr=None)
            return self._slice_intervals(y, sr, intervals, self.MAX_SECONDS), sr
        # NOTE(review): without intervals librosa.load resamples to its default
        # rate, while the branch above keeps the native rate, so the two paths
        # can produce spectrograms with different frequency ranges. Preserved
        # in case a downstream model depends on it — confirm before unifying.
        return librosa.load(audio_path, duration=self.MAX_SECONDS)

    def create_spectrogram_tensor(self, audio_path, intervals=None):
        """
        Render a mel-spectrogram of audio_path and return it as a model input.

        Args:
            audio_path: Path of the audio file to load.
            intervals: Optional list of (start_ms, end_ms) spans to keep; when
                falsy, the first MAX_SECONDS seconds of the file are used.

        Returns:
            Tensor of shape (1, 3, 224, 224). On any failure the error is
            printed and an all-zero tensor of the same shape is returned.
        """
        from io import BytesIO

        fig = None
        try:
            # 2.24 in x 100 dpi -> 224x224-pixel canvas with no margins.
            fig = plt.figure(figsize=(2.24, 2.24), dpi=100)
            ax = fig.add_subplot(1, 1, 1)
            fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

            y, sr = self._load_audio(audio_path, intervals)
            ms = librosa.feature.melspectrogram(y=y, sr=sr)
            log_ms = librosa.power_to_db(ms, ref=np.max)
            librosa.display.specshow(log_ms, sr=sr, ax=ax)

            # Save to buffer instead of file
            buf = BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            image = Image.open(buf).convert('RGB')
            return self.vit_transforms(image).unsqueeze(0)
        except Exception as e:
            print(f"Spectrogram creation failed: {e}")
            return torch.zeros((1, 3, 224, 224))
        finally:
            # Always release the figure, including on the error path; pyplot
            # keeps every open figure alive until it is closed.
            if fig is not None:
                plt.close(fig)

    def create_spectrogram_base64(self, audio_path, intervals=None):
        """
        Render a labelled mel-spectrogram (time/mel axes, dB colorbar, title)
        of audio_path for display.

        Args:
            audio_path: Path of the audio file to load.
            intervals: Optional list of (start_ms, end_ms) spans to keep; when
                falsy, the first MAX_SECONDS seconds of the file are used.

        Returns:
            A "data:image/png;base64,..." string, or None on any failure
            (the error is printed).
        """
        import base64
        from io import BytesIO

        fig = None
        try:
            fig = plt.figure(figsize=(4, 3), dpi=100)
            ax = fig.add_subplot(1, 1, 1)

            y, sr = self._load_audio(audio_path, intervals)
            ms = librosa.feature.melspectrogram(y=y, sr=sr)
            log_ms = librosa.power_to_db(ms, ref=np.max)
            img = librosa.display.specshow(log_ms, sr=sr, x_axis='time', y_axis='mel', ax=ax)
            fig.colorbar(img, ax=ax, format='%+2.0f dB')
            ax.set_title('Mel-Spectrogram')

            buf = BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight')
            b64_str = base64.b64encode(buf.getvalue()).decode('utf-8')
            return f"data:image/png;base64,{b64_str}"
        except Exception as e:
            print(f"Spectrogram base64 creation failed: {e}")
            return None
        finally:
            if fig is not None:
                plt.close(fig)
| # ========================================== | |
| # 3. ASR Helper (Whisper + CHAT Rules) | |
| # ========================================== | |
def apply_chat_rules(transcription_result, pause_threshold=0.3, long_pause_threshold=0.8):
    """
    Convert a Whisper-style transcription result into CHAT-like text with
    pause and repetition markers.

    For each entry of ``transcription_result['segments']`` (each needs
    ``'start'``, ``'end'`` and ``'text'``):
      * the gap since the previous segment's end (since 0 for the first
        segment) becomes ``"[PAUSE] (..)"`` if it exceeds
        ``long_pause_threshold``, else ``"[PAUSE] (.)"`` if it exceeds
        ``pause_threshold``;
      * a word immediately repeated within the segment (compared on its
        lower-cased ASCII letters only) gets ``" [/]"`` appended to its
        earlier occurrence, e.g. ``"I I went"`` -> ``"I [/] I went"``.

    Args:
        transcription_result: Dict with an optional ``'segments'`` list;
            a missing key is treated as no segments.
        pause_threshold: Gap, in the same units as the segment times, that
            must be exceeded for a short pause marker.
        long_pause_threshold: Gap that must be exceeded for a long pause marker.

    Returns:
        Pause markers and processed segment texts joined by single spaces.
        Segments with blank text contribute only their pause marker, if any.
    """
    formatted_text = []
    last_end = 0
    for seg in transcription_result.get('segments', []):
        # Insert [PAUSE] token + CHAT marker
        gap = seg['start'] - last_end
        if gap > long_pause_threshold:
            formatted_text.append("[PAUSE] (..)")
        elif gap > pause_threshold:
            formatted_text.append("[PAUSE] (.)")
        last_end = seg['end']

        words = seg['text'].split()
        if not words:
            # Appending "" would leave a double (or trailing) space in the join.
            continue

        # Repetitions (basic detection): compare each word with the previous one.
        processed_words = []
        prev_clean = None
        for w in words:
            clean_w = re.sub(r'[^a-zA-Z]', '', w.lower())
            if clean_w and clean_w == prev_clean:
                processed_words[-1] += " [/]"
            processed_words.append(w)
            prev_clean = clean_w
        formatted_text.append(" ".join(processed_words))
    return " ".join(formatted_text)