# MATSYA MVP API — Flask service: speaker diarization (pyannote) + per-segment
# Whisper transcription/translation. (Hosted as a Hugging Face Space.)
import logging
import os
import tempfile

import torch
import torchaudio
from accelerate import Accelerator
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_restx import Api, Resource, fields, reqparse
from pyannote.audio import Pipeline
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import HTTPException, RequestEntityTooLarge
from werkzeug.utils import secure_filename
# Configuration — every value is overridable via environment variables.
MAX_CONTENT_LENGTH = int(os.getenv('MAX_CONTENT_LENGTH', 50 * 1024 * 1024)) # 50 MB
ALLOWED_EXTENSIONS = {'wav', 'mp3', 'flac', 'm4a'}  # accepted upload extensions (lowercase)
MODEL_NAME = os.getenv('WHISPER_MODEL', 'openai/whisper-large-v2') # Keep whisper-large-v2
HF_TOKEN = os.getenv('HF_TOKEN')  # Hugging Face token for the gated pyannote pipeline
HOST = os.getenv('HOST', '0.0.0.0')
PORT = int(os.getenv('PORT', 7860))
DEFAULT_LANGUAGE = os.getenv('DEFAULT_LANGUAGE', 'hi')  # Whisper language code (hi = Hindi)
DEFAULT_TASK = 'transcribe'  # alternative accepted by the endpoint: 'translate'
MAX_SEGMENTS = int(os.getenv('MAX_SEGMENTS', 10)) # Limit segments to reduce memory usage
# Initialize Flask
app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = MAX_CONTENT_LENGTH  # oversize uploads raise RequestEntityTooLarge (413)
CORS(app)  # allow cross-origin browser clients
api = Api(
    app,
    version='1.0',
    title='MATSYA_mvp_API',
    description='Upload audio → speaker diarization + segment-wise transcription/translation'
)
# All routes in this API live under /transcribe
ns = api.namespace('transcribe', description='Audio operations')
# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('MATSYA_mvp_api')
# Initialize Accelerator for memory-efficient model loading
accelerator = Accelerator()
# Load Whisper on CPU with accelerate. Fail fast (re-raise) if the model
# cannot be fetched — the service is useless without it.
device = 'cpu'
try:
    processor = WhisperProcessor.from_pretrained(MODEL_NAME, cache_dir="/data/huggingface")
    model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="/data/huggingface")
    # Only the torch model goes through accelerator.prepare(); the processor
    # is a tokenizer/feature-extractor, not a torch object, and prepare()
    # either passes it through unchanged or mishandles it depending on the
    # accelerate version — so don't feed it in.
    model = accelerator.prepare(model)
except Exception as e:
    logger.error("Failed to load Whisper model: %s", e)
    raise
# Swagger response schema: a flat list of per-speaker segment dicts.
response_model = api.model('DiarizationResponse', {
    'segments': fields.List(fields.Raw, description='Speaker segments with text')
})

# Multipart/form request arguments for the transcription endpoint.
parser = reqparse.RequestParser()
parser.add_argument('audio', location='files', type=FileStorage, required=True,
                    help='Audio file (wav/mp3/flac/m4a)')
parser.add_argument('language', type=str, required=False,
                    help='Language code (e.g., hi for Hindi, en for English)')
parser.add_argument('task', type=str, required=False,
                    help='Task: transcribe or translate')
def allowed_file(filename, allowed=None):
    """Return True if *filename* has an extension in the allowed set.

    Args:
        filename: client-supplied file name; must contain a '.' to qualify.
        allowed: optional set of lowercase extensions; defaults to the
            module-level ALLOWED_EXTENSIONS (backward compatible).
    """
    allowed = ALLOWED_EXTENSIONS if allowed is None else allowed
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed
@app.errorhandler(RequestEntityTooLarge)
def handle_too_large(e):
    """413 handler: request body exceeded MAX_CONTENT_LENGTH.

    Registered via @app.errorhandler — the original function was defined but
    never hooked into Flask, so oversized uploads returned the default HTML 413.
    """
    return jsonify({'error': 'File too large'}), 413
@app.errorhandler(Exception)
def handle_all(e):
    """Catch-all 500 handler: log the traceback, return a JSON error body.

    Registered via @app.errorhandler — it was previously defined but never
    hooked into Flask. HTTP exceptions (404, 413, api.abort, ...) are passed
    through unchanged so their status codes are preserved.
    """
    if isinstance(e, HTTPException):
        return e
    logger.exception("Unhandled exception")
    return jsonify({'error': str(e)}), 500
@ns.route('/')  # was never routed — the resource was unreachable without this
class Transcription(Resource):
    """POST an audio file → speaker diarization + per-segment Whisper text."""

    @ns.expect(parser)
    def post(self):
        """Diarize the uploaded audio, then transcribe/translate each turn.

        Form fields: 'audio' (required file), 'language', 'task'.

        Returns:
            dict: {'segments': [{speaker, start, end, text, language, task}]}

        Aborts with 400 for a missing/invalid file and 500 on model or I/O
        failures; individual segment inference failures are skipped, not fatal.
        """
        # Lazy-load the pyannote pipeline per request to keep idle memory low.
        # NOTE(review): this reload makes every request slow — confirm whether
        # caching the pipeline at module level is acceptable memory-wise.
        try:
            diarization_pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization",
                use_auth_token=HF_TOKEN,
                cache_dir="/data/huggingface"
            )
            diarization_pipeline = accelerator.prepare(diarization_pipeline)
        except Exception as e:
            logger.error(f"Failed to load pyannote pipeline: {str(e)}")
            api.abort(500, f"Failed to load diarization pipeline: {str(e)}")

        args = parser.parse_args()
        audio_file = args['audio']
        language = args['language'] or DEFAULT_LANGUAGE
        task = args['task'] or DEFAULT_TASK

        if not audio_file:
            api.abort(400, "No audio file uploaded")
        filename = audio_file.filename
        if filename == '' or not allowed_file(filename):
            api.abort(400, "Invalid file type")

        # Unique, sanitized temp path: secure_filename() blocks path traversal
        # through the client-controlled name, and mkstemp() prevents two
        # concurrent uploads with the same name clobbering each other.
        suffix = os.path.splitext(secure_filename(filename))[1]
        fd, save_path = tempfile.mkstemp(suffix=suffix, dir='/tmp')
        os.close(fd)

        try:
            try:
                audio_file.save(save_path)
            except Exception as e:
                logger.error(f"Failed to save audio file: {str(e)}")
                api.abort(500, f"Failed to save audio file: {str(e)}")

            # Run diarization
            try:
                diarization = diarization_pipeline(save_path)
            except Exception as e:
                logger.error(f"Diarization failed: {str(e)}")
                api.abort(500, f"Diarization failed: {str(e)}")

            try:
                waveform, sr = torchaudio.load(save_path)
            except Exception as e:
                logger.error(f"Failed to load audio: {str(e)}")
                api.abort(500, f"Failed to load audio: {str(e)}")

            # sr is constant for the whole file, so build the (loop-invariant)
            # resampler once instead of once per segment.
            resampler = torchaudio.transforms.Resample(sr, 16000) if sr != 16000 else None

            segments = []
            for i, (turn, _, speaker) in enumerate(diarization.itertracks(yield_label=True)):
                if i >= MAX_SEGMENTS:  # cap segment count to bound memory usage
                    logger.warning(f"Stopped processing after {MAX_SEGMENTS} segments to avoid OOM")
                    break
                start, end = turn.start, turn.end
                chunk = waveform[:, int(start * sr):int(end * sr)]

                # Whisper expects 16 kHz mono input.
                if resampler is not None:
                    chunk = resampler(chunk)
                if chunk.shape[0] > 1:
                    chunk = chunk.mean(dim=0, keepdim=True)

                # Whisper inference — a failure here skips the segment only.
                try:
                    inputs = processor(
                        chunk.squeeze().numpy(),
                        sampling_rate=16000,
                        return_tensors='pt'
                    ).input_features.to(device)
                    predicted_ids = model.generate(inputs, language=language, task=task)
                    text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
                except Exception as e:
                    logger.error(f"Whisper inference failed for segment {i}: {str(e)}")
                    continue  # Skip failed segments to avoid crashing

                segments.append({
                    'speaker': speaker,
                    'start': round(start, 2),
                    'end': round(end, 2),
                    'text': text,
                    'language': language,
                    'task': task
                })

            return {'segments': segments}
        finally:
            # Always delete the temp file — the original leaked it whenever
            # the segment loop (or anything after torchaudio.load) raised.
            if os.path.exists(save_path):
                os.remove(save_path)
if __name__ == '__main__':
    # debug=True exposes the Werkzeug interactive debugger (arbitrary code
    # execution) on a server bound to 0.0.0.0 — only enable it on request.
    debug = os.getenv('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host=HOST, port=PORT, debug=debug)