File size: 4,360 Bytes
ae238b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import argparse
import logging
import os
import sys
import tempfile


# Set up basic logging first: attach a stdout handler at INFO level to the
# root logger before the remaining imports below are executed (presumably so
# that anything logged while those modules import is already visible).
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)

# Using direct imports
from env_vars import API_LOG_LEVEL

import torch
from flask import Flask, jsonify, send_file, send_from_directory
from flask_cors import CORS
from inference.audio_chunker import AudioChunker
from inference.audio_sentence_alignment import AudioAlignment
from inference.mms_model_pipeline import MMSModel
from transcriptions_blueprint import transcriptions_blueprint

# Configure logging with the imported level.
# A second logging.basicConfig() call would do nothing here: the root logger
# already has the handler installed by the first call, and basicConfig()
# only acts when the root logger has no handlers. Set the level on the root
# logger directly so API_LOG_LEVEL actually takes effect.
try:
    logging.getLogger().setLevel(API_LOG_LEVEL)
except (TypeError, ValueError):
    # An unrecognised level must not prevent the server from starting;
    # keep the INFO level configured above and report the problem.
    logging.getLogger(__name__).warning(
        "Ignoring invalid API_LOG_LEVEL %r; keeping INFO", API_LOG_LEVEL
    )


# Module-level load state, read and written by load_model().
# These are plain globals with no lock around them.
_model_loaded = False  # set to True once load_model() finishes without error
_model_loading = False  # True only while load_model() is inside its load step


def load_model():
    """Instantiate the inference components once at server startup.

    Constructs ``AudioChunker``, ``AudioAlignment`` and ``MMSModel`` (placed
    on ``cuda:0`` when CUDA is available, otherwise on the CPU) and records
    the outcome in the module-level ``_model_loaded`` flag. The constructed
    objects are not stored here; presumably those classes keep their own
    shared state -- TODO confirm.

    Any exception raised while loading is logged together with its traceback
    and then swallowed, leaving ``_model_loaded`` set to ``False``.

    Returns:
        None, in every case.
    """
    global _model_loaded, _model_loading

    # Nothing to do when a previous call already succeeded.
    if _model_loaded:
        logger.info("Model already loaded, skipping load")
        return

    # Despite the wording of the log message, this branch does not wait:
    # it returns immediately while another call is still loading.
    if _model_loading:
        logger.info("Model is currently being loaded, waiting...")
        return

    _model_loading = True
    try:
        logger.info("Loading MMS model...")

        # Initialize other components
        AudioChunker()
        AudioAlignment()

        # Initialize the pipeline-based MMS model on GPU when one is present.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        MMSModel(device=device)

        logger.info("✓ MMS pipeline loaded successfully during server startup")

        _model_loaded = True
        logger.info("Models successfully loaded")
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error(str(e))
        # would discard; the message text itself is unchanged.
        logger.exception("Failed to load model: %s", e)
        _model_loaded = False
    finally:
        # Always clear the in-progress flag, whether loading worked or not.
        _model_loading = False


# Create the Flask application and mount the transcription routes on it.
app = Flask(__name__)
app.register_blueprint(transcriptions_blueprint)
# Enable CORS for every route ("/*"): any origin, all request headers
# allowed, all response headers exposed, credentials supported.
# NOTE(review): a wildcard origin together with supports_credentials=True is
# very permissive -- confirm against the Flask-CORS documentation that this
# combination is intended for deployment.
cors = CORS(
    app,
    resources={
        r"/*": {
            "origins": "*",
            "allow_headers": "*",
            "expose_headers": "*",
            "supports_credentials": True,
        }
    },
)

# Same logger object as the one obtained near the top of the module;
# logging.getLogger() returns one instance per name.
logger = logging.getLogger(__name__)
# Make Flask's app logger reuse the handlers and level of the
# "gunicorn.error" logger, so app.logger output follows gunicorn's logging
# configuration (assumes the app is served by gunicorn -- TODO confirm the
# behaviour wanted when it is not).
gunicorn_logger = logging.getLogger("gunicorn.error")
app.logger.handlers = gunicorn_logger.handlers
app.logger.setLevel(gunicorn_logger.level)

# Load model on startup - only once during app initialization.
# This is module-level code, so it runs when the module is imported; the
# _model_loaded check skips the load if the flag is already set.
logger.info("Initializing application and loading model...")
if not _model_loaded:
    load_model()
else:
    logger.info("Model already loaded, skipping initialization")


# Frontend static file serving
@app.route("/")
def serve_frontend():
    """Return the built frontend's ``index.html`` for the site root."""
    # The build output is expected at <parent of this file's dir>/frontend/dist.
    project_root = os.path.dirname(os.path.dirname(__file__))
    index_path = os.path.join(project_root, "frontend", "dist", "index.html")
    return send_file(index_path)


@app.route("/assets/<path:filename>")
def serve_assets(filename):
    """Return one static file from the frontend build's ``assets`` folder.

    ``filename`` is the path portion captured after ``/assets/``.
    """
    # The build output is expected at <parent of this file's dir>/frontend/dist.
    project_root = os.path.dirname(os.path.dirname(__file__))
    assets_dir = os.path.join(project_root, "frontend", "dist", "assets")
    return send_from_directory(assets_dir, filename)


# Catch-all route for SPA routing - must be last
@app.route("/<path:path>")
def serve_spa(path):
    """Fallback route: hand unmatched paths to the single-page frontend.

    Paths beginning with ``api/`` get a JSON 404 response instead; every
    other path receives the frontend's ``index.html``.
    """
    if not path.startswith("api/"):
        # Not an API path: serve the frontend entry page from the build output.
        project_root = os.path.dirname(os.path.dirname(__file__))
        index_path = os.path.join(project_root, "frontend", "dist", "index.html")
        return send_file(index_path)

    # Unknown API endpoint: answer with JSON rather than HTML.
    return jsonify({"error": "API endpoint not found"}), 404


@app.errorhandler(404)
def handle_404(e):
    """Answer 404 errors with a JSON body and a 404 status code."""
    payload = {"error": "Endpoint not found"}
    return jsonify(payload), 404


@app.errorhandler(500)
def handle_500(e):
    """Log a 500 error and answer with a generic JSON body and status 500."""
    logger.error(f"Internal server error: {e!s}")
    payload = {"error": "Internal server error"}
    return jsonify(payload), 500


def str_to_bool(value):
    """Convert a command-line token into a bool, for use as argparse ``type=``.

    Args:
        value: The raw argument. Strings are matched case-insensitively after
            stripping whitespace; an actual ``bool`` is returned unchanged.

    Returns:
        True for "1", "true", "t", "yes", "y", "on"; False for "0", "false",
        "f", "no", "n", "off" and the empty string.

    Raises:
        argparse.ArgumentTypeError: If the token is none of the above.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string stays falsy, as it was with type=bool.
    if token in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", default=5000, type=int)
    # type=bool would turn every non-empty string -- including "False" --
    # into True, so "--debug False" could never disable debug mode.
    # NOTE(review): debug defaults to True while the default host listens on
    # all interfaces; confirm this default is only used for local development.
    parser.add_argument("--debug", default=True, type=str_to_bool)
    args = parser.parse_args()

    logger.info(f"Starting Translations API on {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=args.debug)