# MATSYA_mvp / app.py
# (Hugging Face Spaces header residue: uploaded by koyu008 — "Update app.py",
# commit e4ec1e0, verified)
import os
import logging
from flask import Flask, request, jsonify
from werkzeug.exceptions import RequestEntityTooLarge
from flask_cors import CORS
from flask_restx import Api, Resource, fields, reqparse
from werkzeug.datastructures import FileStorage
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from pyannote.audio import Pipeline
from accelerate import Accelerator
# --- Configuration: every knob is overridable through the environment --------
def _int_env(name, default):
    """Read *name* from the environment as an int, falling back to *default*."""
    return int(os.getenv(name, default))

MAX_CONTENT_LENGTH = _int_env('MAX_CONTENT_LENGTH', 50 * 1024 * 1024)  # 50 MB upload cap
ALLOWED_EXTENSIONS = {'wav', 'mp3', 'flac', 'm4a'}
MODEL_NAME = os.getenv('WHISPER_MODEL', 'openai/whisper-large-v2')  # Keep whisper-large-v2
HF_TOKEN = os.getenv('HF_TOKEN')
HOST = os.getenv('HOST', '0.0.0.0')
PORT = _int_env('PORT', 7860)
DEFAULT_LANGUAGE = os.getenv('DEFAULT_LANGUAGE', 'hi')
DEFAULT_TASK = 'transcribe'
MAX_SEGMENTS = _int_env('MAX_SEGMENTS', 10)  # bound per-request work to limit memory use
# --- Flask application setup -------------------------------------------------
app = Flask(__name__)
# Werkzeug rejects bodies above this cap with RequestEntityTooLarge (handled below).
app.config['MAX_CONTENT_LENGTH'] = MAX_CONTENT_LENGTH
CORS(app)  # allow cross-origin browser clients
# Flask-RESTX wires the routes and serves Swagger UI at the API root.
api = Api(
    app,
    version='1.0',
    title='MATSYA_mvp_API',
    description='Upload audio → speaker diarization + segment-wise transcription/translation'
)
ns = api.namespace('transcribe', description='Audio operations')
# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('MATSYA_mvp_api')
# Initialize Accelerator for memory-efficient model loading
# NOTE(review): on a CPU-only deployment prepare() is largely a pass-through —
# confirm Accelerate is still wanted here.
accelerator = Accelerator()
# Load Whisper on CPU with accelerate.
# Fix: the original passed the processor through accelerator.prepare() as well,
# but a WhisperProcessor is a tokenizer/feature-extractor, not a torch module —
# prepare() only applies to the model.
device = 'cpu'
try:
    processor = WhisperProcessor.from_pretrained(MODEL_NAME, cache_dir="/data/huggingface")
    model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="/data/huggingface")
    model = accelerator.prepare(model)  # memory-efficient placement
    model.eval()  # inference-only service: disable dropout etc.
except Exception as e:
    logger.error(f"Failed to load Whisper model: {str(e)}")
    raise
# Swagger response schema: a bare list of segment dicts.
response_model = api.model('DiarizationResponse', {
    'segments': fields.List(fields.Raw, description='Speaker segments with text')
})

# Multipart request parser for the upload endpoint, built from a spec table.
parser = reqparse.RequestParser()
for _arg_name, _arg_opts in (
    ('audio', dict(location='files',
                   type=FileStorage,
                   required=True,
                   help='Audio file (wav/mp3/flac/m4a)')),
    ('language', dict(type=str,
                      required=False,
                      help='Language code (e.g., hi for Hindi, en for English)')),
    ('task', dict(type=str,
                  required=False,
                  help='Task: transcribe or translate')),
):
    parser.add_argument(_arg_name, **_arg_opts)
def allowed_file(filename, allowed=None):
    """Return True if *filename* has an extension in the accepted set.

    Args:
        filename: Upload filename (may contain multiple dots).
        allowed: Optional set of lowercase extensions to check against;
            defaults to the module-level ALLOWED_EXTENSIONS, so existing
            callers are unaffected.
    """
    if allowed is None:
        allowed = ALLOWED_EXTENSIONS
    # rpartition: sep is '' when the name carries no dot at all.
    _stem, sep, ext = filename.rpartition('.')
    return sep == '.' and ext.lower() in allowed
@app.errorhandler(RequestEntityTooLarge)
def handle_too_large(e):
    """Translate an oversized upload (beyond MAX_CONTENT_LENGTH) into JSON 413."""
    payload = jsonify({'error': 'File too large'})
    return payload, 413
@app.errorhandler(Exception)
def handle_all(e):
    """Catch-all handler: log the full traceback, surface the message as JSON 500."""
    logger.exception("Unhandled exception")
    body = {'error': str(e)}
    return jsonify(body), 500
@ns.route('/')
class Transcription(Resource):
    """Upload endpoint: diarize speakers, then transcribe/translate each turn."""

    @ns.doc('upload_audio')
    @ns.expect(parser)
    @ns.marshal_with(response_model)
    def post(self):
        """Handle one audio upload.

        Returns:
            dict with a 'segments' list; each entry holds speaker label,
            start/end times in seconds, transcribed text, language and task.
        """
        # Lazy-load the pyannote pipeline per request with an explicit cache_dir.
        # NOTE(review): reloading on every request is slow — consider caching
        # the pipeline at module level once memory allows.
        try:
            diarization_pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization",
                use_auth_token=HF_TOKEN,
                cache_dir="/data/huggingface"
            )
            diarization_pipeline = accelerator.prepare(diarization_pipeline)
        except Exception as e:
            logger.error(f"Failed to load pyannote pipeline: {str(e)}")
            api.abort(500, f"Failed to load diarization pipeline: {str(e)}")

        args = parser.parse_args()
        audio_file = args['audio']
        language = args['language'] or DEFAULT_LANGUAGE
        task = args['task'] or DEFAULT_TASK

        if not audio_file:
            api.abort(400, "No audio file uploaded")
        # basename() strips any directory components a malicious client could
        # embed in the filename (path-traversal hardening for the /tmp write).
        filename = os.path.basename(audio_file.filename or '')
        if filename == '' or not allowed_file(filename):
            api.abort(400, "Invalid file type")

        try:
            save_path = os.path.join('/tmp', filename)
            audio_file.save(save_path)
        except Exception as e:
            logger.error(f"Failed to save audio file: {str(e)}")
            api.abort(500, f"Failed to save audio file: {str(e)}")

        # From here the temp file exists; the finally block guarantees cleanup
        # even when diarization/transcription raises unexpectedly (the original
        # leaked the file on any error past the explicit handlers).
        try:
            try:
                diarization = diarization_pipeline(save_path)
            except Exception as e:
                logger.error(f"Diarization failed: {str(e)}")
                api.abort(500, f"Diarization failed: {str(e)}")

            try:
                waveform, sr = torchaudio.load(save_path)
            except Exception as e:
                logger.error(f"Failed to load audio: {str(e)}")
                api.abort(500, f"Failed to load audio: {str(e)}")

            # sr is constant for the whole file, so build the resampler once
            # instead of once per segment.
            resampler = torchaudio.transforms.Resample(sr, 16000) if sr != 16000 else None

            segments = []
            for i, (turn, _, speaker) in enumerate(diarization.itertracks(yield_label=True)):
                if i >= MAX_SEGMENTS:  # cap work to bound memory usage
                    logger.warning(f"Stopped processing after {MAX_SEGMENTS} segments to avoid OOM")
                    break
                start, end = turn.start, turn.end
                chunk = waveform[:, int(start * sr):int(end * sr)]
                if resampler is not None:
                    chunk = resampler(chunk)
                # Whisper expects mono 16 kHz input.
                if chunk.shape[0] > 1:
                    chunk = chunk.mean(dim=0, keepdim=True)
                try:
                    inputs = processor(
                        chunk.squeeze().numpy(),
                        sampling_rate=16000,  # always 16 kHz after the resample branch
                        return_tensors='pt'
                    ).input_features.to(device)
                    with torch.no_grad():  # inference only: skip autograd bookkeeping
                        predicted_ids = model.generate(inputs, language=language, task=task)
                    text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
                except Exception as e:
                    logger.error(f"Whisper inference failed for segment {i}: {str(e)}")
                    continue  # Skip failed segments to avoid crashing
                segments.append({
                    'speaker': speaker,
                    'start': round(start, 2),
                    'end': round(end, 2),
                    'text': text,
                    'language': language,
                    'task': task
                })
        finally:
            if os.path.exists(save_path):
                os.remove(save_path)
        return {'segments': segments}
if __name__ == '__main__':
    # Fix: debug was hard-coded to True. Werkzeug's interactive debugger allows
    # arbitrary code execution, so it must never be on by default in a deployed
    # service — enable only when FLASK_DEBUG is explicitly set.
    debug = os.getenv('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host=HOST, port=PORT, debug=debug)