File size: 3,252 Bytes
28a4a57
 
2f93a93
28a4a57
 
 
 
 
 
d89e14f
28a4a57
 
 
 
 
 
 
 
 
 
 
d89e14f
28a4a57
fd6cbd8
28a4a57
fd6cbd8
d89e14f
2f93a93
28a4a57
d89e14f
fd6cbd8
28a4a57
 
d89e14f
2f93a93
d89e14f
 
28a4a57
d89e14f
28a4a57
 
 
d89e14f
 
 
 
28a4a57
d89e14f
 
 
2f93a93
 
 
d89e14f
28a4a57
fd6cbd8
8ee6822
28a4a57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ee6822
28a4a57
 
8ee6822
28a4a57
 
 
8ee6822
d89e14f
28a4a57
 
 
fd6cbd8
d89e14f
28a4a57
 
2f93a93
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
import gradio as gr
import whisperx
import torch
import tempfile
import os
import uvicorn
from threading import Thread

# Device setup
# Prefer GPU when available; float16 compute is only safe on CUDA, so the
# CPU path falls back to int8 quantization for speed/memory.
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"

# Log the selected backend once at startup so deployments are easy to verify.
print(f"🚀 Device: {device}, Compute: {compute_type}")

# Create FastAPI app
app = FastAPI(title="WhisperX Alignment API")

# Per-process model caches: loading the whisper model and the per-language
# alignment model is expensive, so reuse them across requests instead of
# reloading on every call (the original reloaded both per request).
_whisper_cache: dict = {}
_align_cache: dict = {}


def _get_whisper_model():
    """Load the WhisperX "base" model once and return the cached instance."""
    if "model" not in _whisper_cache:
        _whisper_cache["model"] = whisperx.load_model(
            "base", device=device, compute_type=compute_type
        )
    return _whisper_cache["model"]


def _get_align_model(language: str):
    """Load the alignment model for *language* once; return (model, metadata)."""
    if language not in _align_cache:
        _align_cache[language] = whisperx.load_align_model(
            language_code=language, device=device
        )
    return _align_cache[language]


def process_audio(audio_path: str, language: str = "en"):
    """Transcribe an audio file and align word-level timestamps.

    Parameters
    ----------
    audio_path : str
        Path to an audio file readable by whisperx (via ffmpeg).
    language : str
        Language code passed to both transcription and alignment.

    Returns
    -------
    dict
        On success: ``word_segments`` (list of ``{"word", "start", "end"}``),
        ``duration``, ``word_count``, ``language`` and ``device``.
        On failure: ``{"error": <message>}`` — errors are reported in-band
        rather than raised so both the REST API and the Gradio UI can
        surface them to the caller.
    """
    try:
        print(f"📝 Processing {audio_path} ({language})...")

        # Transcribe with the cached whisper model.
        model = _get_whisper_model()
        result = model.transcribe(audio_path, language=language)

        # Align with the cached per-language alignment model.
        align_model, metadata = _get_align_model(language)
        aligned = whisperx.align(result["segments"], align_model, metadata, audio_path, device=device)

        # Flatten aligned segments into a single word-level timestamp list.
        word_segments = [
            {
                "word": word["word"].strip(),
                "start": round(word["start"], 2),
                "end": round(word["end"], 2),
            }
            for segment in aligned["segments"]
            for word in segment.get("words", [])
        ]

        # Total duration = end time of the last aligned segment (0 if none).
        duration = aligned["segments"][-1]["end"] if aligned["segments"] else 0

        return {
            "word_segments": word_segments,
            "duration": round(duration, 2),
            "word_count": len(word_segments),
            "language": language,
            "device": device,
        }
    except Exception as e:
        # Best-effort boundary: log and report the failure in-band instead
        # of crashing the request/UI thread.
        print(f"❌ Error: {e}")
        return {"error": str(e)}

# FastAPI endpoint
@app.post("/align")
async def align_audio_api(
    audio_file: UploadFile = File(...),
    language: str = Form("en")
):
    """REST API endpoint for audio alignment.

    Accepts an uploaded audio file plus a language code and returns the
    word-level timestamps produced by ``process_audio``. Processing
    failures are reported with HTTP 500 (error message in the JSON body)
    so clients can detect them from the status code alone.
    """
    temp_path = None
    try:
        # Persist the upload to disk so whisperx/ffmpeg can read it.
        # Keep the upload's real extension when present (the file is not
        # necessarily an mp3); fall back to ".mp3" for unnamed uploads.
        suffix = os.path.splitext(audio_file.filename or "")[1] or ".mp3"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(await audio_file.read())
            temp_path = tmp.name

        # Process
        result = process_audio(temp_path, language)
        # process_audio reports failures in-band via an "error" key.
        status = 500 if "error" in result else 200
        return JSONResponse(result, status_code=status)

    finally:
        # Always clean up the temp file, even when processing raised.
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)

@app.get("/")
def health():
    """Liveness probe: report service status and the active compute device."""
    payload = {"status": "healthy"}
    payload["device"] = device
    return payload

# Gradio interface
def align_gradio(audio_file, language="en"):
    """Gradio UI wrapper: delegate to process_audio, guarding empty input."""
    if audio_file:
        return process_audio(audio_file, language)
    return {"error": "No file"}

# Gradio UI: one audio upload plus a free-text language code in,
# the JSON result of align_gradio out.
gradio_app = gr.Interface(
    fn=align_gradio,
    inputs=[
        gr.Audio(type="filepath", label="Audio"),
        gr.Textbox(value="en", label="Language")
    ],
    outputs=gr.JSON(label="Result"),
    title="🎯 WhisperX Alignment",
    description="Upload audio for word-level timestamps"
)

# Mount Gradio to FastAPI
# NOTE(review): the health() route is also registered at "/"; whether GET /
# serves the health JSON or the Gradio UI depends on FastAPI/Gradio routing
# precedence — confirm, or mount the UI at a dedicated path like "/ui".
app = gr.mount_gradio_app(app, gradio_app, path="/")

# Launch
# Port 7860 is the conventional Gradio/HF Spaces port; bind all interfaces.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)