File size: 7,843 Bytes
34ac5c4
98a891c
34ac5c4
 
 
98a891c
 
 
f3ada2c
29c2fd9
fa01798
29c2fd9
f3ada2c
29c2fd9
f3ada2c
34ac5c4
29c2fd9
fa01798
f3ada2c
29c2fd9
f3ada2c
 
29c2fd9
f3ada2c
 
 
29c2fd9
 
f3ada2c
29c2fd9
 
f3ada2c
29c2fd9
34ac5c4
fa01798
 
f3ada2c
 
98a891c
f3ada2c
98a891c
 
 
 
 
 
 
 
fa01798
98a891c
 
4d3b6d4
 
 
 
 
 
98a891c
 
 
 
4d3b6d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98a891c
4d3b6d4
98a891c
4d3b6d4
 
 
 
 
fa01798
 
 
98a891c
4d3b6d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98a891c
4d3b6d4
 
98a891c
 
fa01798
98a891c
 
 
 
 
 
 
 
 
 
 
 
 
fa01798
98a891c
5a94ac4
98a891c
 
 
 
 
 
 
5a94ac4
 
 
 
 
 
 
 
 
 
 
98a891c
5a94ac4
98a891c
fa01798
98a891c
 
 
 
 
 
 
 
 
34ac5c4
f3ada2c
fa01798
34ac5c4
8862332
 
fa01798
 
98a891c
fa01798
34ac5c4
 
98a891c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import os
import json
import gradio as gr
import nemo.collections.asr as nemo_asr
import spaces
import librosa
import numpy as np
import torch
from huggingface_hub import hf_hub_download

# 1. SETTINGS - Hugging Face Hub location of the fine-tuned .nemo checkpoint.
REPO_ID = "Ali428/parakeet-weights"
FILENAME = "parakeet_final.nemo"

# Runs once at import time: the checkpoint is downloaded and restored into the
# module-level `model` that `transcribe` uses for every request.
print("🚀 Initializing FYP ASR Engine...")

try:
    # 2. DOWNLOAD THE FILE - no auth token is passed, so this assumes the repo
    # is public.
    print(f"Downloading {FILENAME} from {REPO_ID}...")
    model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

    # 3. LOAD THE NEMO MODEL and switch it to inference mode.
    print("Loading model into memory...")
    model = nemo_asr.models.ASRModel.restore_from(model_path)
    model.eval()
    print("✅ Model loaded successfully!")

except Exception as e:
    print(f"❌ ERROR DURING STARTUP: {e}")
    # Bare `raise` re-raises the active exception with its original traceback
    # instead of adding this line as an extra frame; the app cannot run
    # without a model, so startup must fail loudly.
    raise

# 4. TRANSCRIPTION FUNCTION

# Parakeet TDT uses 8x encoder subsampling (4x conv + 2x additional), so each
# encoder frame = 8 * 10ms = 80ms of real audio time.
_FRAME_DURATION = 0.08
# Number of consecutive words merged into one output segment.
_WORDS_PER_SEGMENT = 15


def _safe_len(value):
    """Return ``len(value)``, or ``None`` when the object has no length (e.g. ``None``, a 0-d tensor)."""
    try:
        return len(value)
    except TypeError:
        return None


def _get_raw_timestep(hypothesis):
    """Return the timing payload attached to a NeMo hypothesis, or ``None``.

    Older NeMo hypotheses expose it as ``timestep``; newer releases (2.x) name
    it ``timestamp``. The first non-empty one wins; an empty one is returned
    only if nothing better exists, and ``None`` if neither attribute is set.
    """
    empty_candidate = None
    for attr in ('timestep', 'timestamp'):
        value = getattr(hypothesis, attr, None)
        if value is None:
            continue
        if _safe_len(value) != 0:
            return value
        if empty_candidate is None:
            empty_candidate = value
    return empty_candidate


def _timestep_to_frames(timestep):
    """Convert a NeMo timing payload into a 1-D numpy array of encoder frame indices.

    Accepts a tensor, a sequence, or a dict holding the sequence under
    ``'timestep'`` (then ``'word'``, then its first value, as before).

    Raises:
        ValueError: if the payload is missing, empty, or not one-dimensional.
        TypeError: if the entries are not numeric (e.g. a list of dicts).
    """
    if timestep is None:
        raise ValueError("timestep is None — will use fallback")

    if isinstance(timestep, dict):
        # Check each key explicitly: chaining with `or` calls bool() on the
        # value, which raises for a multi-element tensor instead of falling
        # through to the next key.
        chosen = None
        for key in ('timestep', 'word'):
            candidate = timestep.get(key)
            if candidate is not None and _safe_len(candidate) != 0:
                chosen = candidate
                break
        if chosen is None:
            if not timestep:
                raise ValueError("timestep dict is empty — will use fallback")
            chosen = next(iter(timestep.values()))
        timestep = chosen

    if isinstance(timestep, torch.Tensor):
        frames = timestep.detach().cpu().numpy()
    else:
        frames = np.asarray(timestep, dtype=np.float32)

    if frames.ndim != 1:
        raise ValueError(f"expected 1-D timestep array, got shape {frames.shape} — will use fallback")
    if frames.size == 0:
        raise ValueError("timestep array is empty — will use fallback")
    return frames


def _proportional_word_timestamps(words, frames, frame_duration=_FRAME_DURATION):
    """Assign each word a start/end time (seconds) from the token frame indices.

    Token timings and words do not line up one-to-one, so word ``i`` is mapped
    to the slice of ``frames`` at the same relative position; its start is the
    first frame of that slice and its end the last.

    Raises:
        ValueError: if ``words`` is empty (nothing to align).
    """
    if not words:
        raise ValueError("no words to align — will use fallback")
    n_frames = len(frames)
    frames_per_word = n_frames / len(words)
    word_timestamps = []
    for i, word in enumerate(words):
        start_idx = min(int(i * frames_per_word), n_frames - 1)
        end_idx = min(int((i + 1) * frames_per_word), n_frames)
        start_frame = int(frames[start_idx])
        # When frames are sparser than words the slice can be empty; the word
        # then gets a zero-length span at its start frame.
        end_frame = int(frames[end_idx - 1]) if end_idx > start_idx else start_frame
        word_timestamps.append({
            'word': word,
            'start': round(start_frame * frame_duration, 3),
            'end': round(end_frame * frame_duration, 3),
        })
    return word_timestamps


def _group_segments(word_timestamps, words_per_segment=_WORDS_PER_SEGMENT):
    """Merge consecutive word timings into segments of up to ``words_per_segment`` words."""
    segments = []
    for i in range(0, len(word_timestamps), words_per_segment):
        chunk = word_timestamps[i:i + words_per_segment]
        segments.append({
            'text': ' '.join(w['word'] for w in chunk),
            'start': chunk[0]['start'],
            'end': chunk[-1]['end'],
        })
    return segments


def _fallback_segments(words, duration, words_per_segment=_WORDS_PER_SEGMENT):
    """Split ``words`` into fixed-size segments with linearly interpolated times.

    Used when no usable timing data exists. Each segment's start/end is
    proportional to its word positions, so the last segment ends exactly at
    ``duration`` (the previous equal-width formula could run past the end of
    the audio when the word count was not a multiple of ``words_per_segment``).
    Returns an empty list for no words.
    """
    total = len(words)
    segments = []
    for i in range(0, total, words_per_segment):
        chunk = words[i:i + words_per_segment]
        segments.append({
            'text': ' '.join(chunk),
            'start': round(duration * i / total, 3),
            'end': round(duration * (i + len(chunk)) / total, 3),
        })
    return segments


@spaces.GPU(duration=120)  # Increased duration slightly for long videos
def transcribe(audio_path):
    """Transcribe one audio file with the module-level NeMo ``model``.

    Args:
        audio_path: Filesystem path to the audio file (the Gradio input is
            declared with ``type="filepath"``), or ``None`` if nothing was
            uploaded.

    Returns:
        A JSON string. On success: ``success``, ``text``,
        ``segment_timestamps`` (list of ``{"text", "start", "end"}``, seconds),
        ``model``, ``device``, ``audio_duration``, ``word_count`` and
        ``debug_info``. On any failure: ``{"success": false, "error": ...}``.
        Exceptions are never propagated to the caller.
    """
    if audio_path is None:
        return json.dumps({"success": False, "error": "Please upload an audio file."})

    try:
        # Decoded only to measure the duration; NeMo reads the file itself below.
        audio, sr = librosa.load(audio_path, sr=16000, mono=True)
        duration = len(audio) / sr

        # return_hypotheses=True yields Hypothesis objects that carry timing data.
        transcription = model.transcribe([audio_path], return_hypotheses=True)
        # Depending on version, element 0 is either the hypothesis itself or a
        # list of best hypotheses; take the one for our single input file.
        first = transcription[0]
        hypothesis = first[0] if isinstance(first, list) else first
        full_text = hypothesis.text
        text_words = full_text.split()

        raw_timestep = _get_raw_timestep(hypothesis)
        timestep_len = _safe_len(raw_timestep)
        timestep_sample = str(raw_timestep)[:100] if raw_timestep is not None else None
        # DEBUG: log timing info to diagnose CPU vs GPU timestamp differences.
        # Uses str()/safe len only, so it cannot crash on dicts, None or 0-d tensors.
        print(
            f"DEBUG timestep type={type(raw_timestep)} | len={timestep_len} | "
            f"sample={timestep_sample} | has_words={hasattr(hypothesis, 'words')} | "
            f"cuda={torch.cuda.is_available()}"
        )

        used_fallback = False
        try:
            words = getattr(hypothesis, 'words', None) or text_words
            frames = _timestep_to_frames(raw_timestep)
            segments = _group_segments(_proportional_word_timestamps(words, frames))
        except Exception as e:
            # Deliberate best effort: any problem with the timing data degrades
            # to interpolated segments rather than failing the request.
            print(f"Warning: Timestamp extraction failed: {e}")
            used_fallback = True
            segments = _fallback_segments(text_words, duration)

        # Construct response identical to local backend
        result = {
            "success": True,
            "text": full_text,
            "segment_timestamps": segments,
            "model": "nvidia/parakeet-tdt-0.6b-v3",
            # NOTE(review): reports CUDA availability only; it does not check
            # which device `model` actually lives on.
            "device": "cuda" if torch.cuda.is_available() else "cpu",
            "audio_duration": round(duration, 2),
            "word_count": len(text_words),
            # DEBUG: remove once timestamps are confirmed working
            "debug_info": {
                "has_timestep": raw_timestep is not None,
                "timestep_type": str(type(raw_timestep)),
                "timestep_len": timestep_len or 0,
                "timestep_sample": timestep_sample,
                "has_words": hasattr(hypothesis, 'words'),
                "cuda_available": torch.cuda.is_available(),
                "segments_from_fallback": used_fallback,
            },
        }

        # Return as pretty JSON string for the Gradio UI
        return json.dumps(result, indent=2)

    except Exception as e:
        import traceback
        traceback.print_exc()
        return json.dumps({
            "success": False,
            "error": str(e)
        }, indent=2)

# 5. GRADIO INTERFACE
# The web UI is mainly kept for manual debugging; programmatic clients call the
# endpoint registered under the name "transcribe" (the exact HTTP route depends
# on the installed Gradio version).
_audio_input = gr.Audio(type="filepath", label="Upload Audio")
# Rendered as a JSON code view because `transcribe` returns a JSON string.
_json_output = gr.Code(language="json", label="Transcription Result JSON")

demo = gr.Interface(
    fn=transcribe,
    inputs=_audio_input,
    outputs=_json_output,
    title="Ali's FYP: Parakeet ASR Engine",
    description="Running on NVIDIA A100 (ZeroGPU). API-ready endpoint.",
    api_name="transcribe",
)

# Launch the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()