Spaces:
Sleeping
Sleeping
import logging
import os
import time
import uuid

from faster_whisper import WhisperModel
from flask import Flask, render_template, request, send_file, after_this_request
from werkzeug.utils import secure_filename
app = Flask(__name__)
app.logger.setLevel(logging.INFO)

# Working directories (relative to the process CWD) and the upload extension allow-list.
app.config.update(
    UPLOAD_FOLDER='uploads',
    OUTPUT_FOLDER='outputs',
    ALLOWED_EXTENSIONS={'mp3', 'wav', 'flac', 'mp4', 'mkv', 'mov', 'm4a', 'ogg', 'webm'},
)
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['OUTPUT_FOLDER'], exist_ok=True)

# Loaded WhisperModel instances keyed by model_type, so repeat requests reuse a
# model instead of loading it from disk again.
model_cache = {}
def get_model(model_type):
    """Return a cached WhisperModel for *model_type*, loading it on first use.

    *model_type* names a model directory: ``/app/models/<model_type>`` is tried
    first, then ``<cwd>/models/<model_type>`` as a local-development fallback.
    Models are loaded on CPU with int8 compute and kept in ``model_cache``.

    Raises:
        ValueError: if *model_type* is not a plain directory name. The value
            arrives from a request form field, so path separators, an absolute
            path, or ``..`` would otherwise let a client point the loader at
            arbitrary filesystem locations.
    """
    # Validate before touching the cache or the filesystem.
    if (
        not model_type
        or model_type in (".", "..")
        or os.path.basename(model_type) != model_type
    ):
        raise ValueError(f"Invalid model type: {model_type!r}")

    if model_type not in model_cache:
        model_path = f"/app/models/{model_type}"
        # Fallback for local development if /app/models doesn't exist
        if not os.path.exists(model_path):
            model_path = os.path.join(os.getcwd(), "models", model_type)
        app.logger.info(f"Loading model: {model_type} from {model_path}")
        model_cache[model_type] = WhisperModel(model_path, device="cpu", compute_type="int8")
    return model_cache[model_type]
def allowed_file(filename):
    """Return True if *filename*'s extension is in ``ALLOWED_EXTENSIONS``.

    Only the text after the last dot counts, compared case-insensitively; a
    name containing no dot is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']
def format_srt_time(seconds):
    """Format *seconds* as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond once, and every field is
    derived from that integer. Deriving the fields separately from the float
    (``int(seconds * 1000 % 1000)``) truncates binary rounding error, e.g.
    1.001 s rendered as ``00:00:01,000``. Negative inputs are clamped to zero
    because SRT has no negative timestamps.
    """
    total_ms = max(0, round(seconds * 1000))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"
def _chunk_from_words(words):
    """Build one subtitle chunk spanning *words* (first word's start to last word's end)."""
    # faster-whisper word strings carry their own leading whitespace (e.g. " hello"),
    # so concatenating them reproduces the segment's spacing. Joining stripped words
    # with " " would insert spurious spaces in languages written without spaces.
    return {
        'start': words[0].start,
        'end': words[-1].end,
        'text': "".join(w.word for w in words).strip(),
    }


def _split_into_chunks(segments, max_duration):
    """Flatten transcription segments into chunks no longer than *max_duration* seconds.

    Segments already within the limit, or lacking word timestamps, are kept
    whole. Longer segments are cut at word boundaries; a single word longer
    than the limit becomes its own chunk.
    """
    chunks = []
    for segment in segments:
        if segment.end - segment.start <= max_duration or not segment.words:
            chunks.append({
                'start': segment.start,
                'end': segment.end,
                'text': segment.text.strip(),
            })
            continue
        current_words = []
        for word in segment.words:
            # Close the running chunk if this word would push it past the limit.
            if current_words and word.end - current_words[0].start > max_duration:
                chunks.append(_chunk_from_words(current_words))
                current_words = []
            current_words.append(word)
        if current_words:
            chunks.append(_chunk_from_words(current_words))
    return chunks


def _write_srt(path, chunks):
    """Write *chunks* to *path* as a UTF-8 SRT file with cues numbered from 1.

    Chunks with empty text are skipped: a blank text line would end the cue
    immediately, producing a cue that some SRT readers reject.
    """
    cues = [chunk for chunk in chunks if chunk['text']]
    with open(path, "w", encoding="utf-8") as srt:
        for cue_number, cue in enumerate(cues, start=1):
            srt.write(f"{cue_number}\n")
            srt.write(f"{format_srt_time(cue['start'])} --> {format_srt_time(cue['end'])}\n")
            srt.write(f"{cue['text']}\n\n")


def transcribe_with_whisper(input_file, output_dir, language, model_type, max_duration):
    """Transcribe *input_file* and write the result as an SRT file in *output_dir*.

    Args:
        input_file: path of the media file to transcribe.
        output_dir: existing directory to write the SRT file into.
        language: language code passed to ``model.transcribe``.
        model_type: model name resolved by ``get_model``.
        max_duration: maximum subtitle chunk length in seconds.

    Returns:
        Path of the SRT file written. The name is unique per call, so concurrent
        requests sharing *output_dir* neither overwrite nor delete each other's output.
    """
    model = get_model(model_type)

    transcribe_start = time.time()
    # faster-whisper returns a generator of segments plus an info object; decoding
    # happens while _split_into_chunks consumes the generator, so it is timed here.
    segments, info = model.transcribe(
        input_file,
        language=language,
        word_timestamps=True
    )
    chunks = _split_into_chunks(segments, max_duration)
    transcribe_duration = time.time() - transcribe_start
    app.logger.info(f"[PROFILING] Transcribing file with {model_type} model took: {transcribe_duration:.2f} seconds")
    app.logger.info(f"[PROFILING] Detected language: {info.language} with probability {info.language_probability:.2f}")

    srt_file = os.path.join(output_dir, f"{uuid.uuid4().hex}.srt")
    srt_save_start = time.time()
    _write_srt(srt_file, chunks)
    srt_save_duration = time.time() - srt_save_start
    app.logger.info(f"[PROFILING] Saving to SRT file took: {srt_save_duration:.2f} seconds")
    return srt_file
def index():
    """Render the landing page from the ``index.html`` template.

    NOTE(review): no route decorator is visible on this view in this copy of
    the file; confirm how it is registered with ``app``.
    """
    page = render_template('index.html')
    return page
def _remove_files(*paths):
    """Delete each path on a best-effort basis; failures are logged, never raised.

    Every path is attempted independently, so a failure on one does not leave
    the others behind.
    """
    for path in paths:
        try:
            os.remove(path)
        except OSError as e:
            app.logger.error(f"Error removing files: {e}")


def transcribe():
    """Validate an uploaded media file, transcribe it, and return the SRT as a download.

    Form fields:
        file: the media upload; its extension must pass ``allowed_file``.
        language: passed to the transcriber (default ``'en'``).
        model_type: model name passed to ``get_model`` (default ``'accurate'``).
        max_duration: maximum subtitle chunk length in seconds. Non-numeric values
            or values outside [1, 5] fall back to 2.0.

    Returns 400 for a missing or invalid upload and 500 if transcription fails.
    The uploaded file is deleted in every case after it has been saved; the SRT
    is deleted once the response has been built.

    NOTE(review): no route decorator is visible on this view in this copy of
    the file; confirm how it is registered with ``app``.
    """
    if 'file' not in request.files:
        return 'No file uploaded', 400
    file = request.files['file']
    if file.filename == '':
        return 'No selected file', 400
    if not allowed_file(file.filename):
        return 'Invalid file type. Allowed types: ' + ', '.join(app.config['ALLOWED_EXTENSIONS']), 400

    language = request.form.get('language', 'en')
    model_type = request.form.get('model_type', 'accurate')
    try:
        max_duration = float(request.form.get('max_duration', 2.0))
        if not (1 <= max_duration <= 5):
            max_duration = 2.0
    except (ValueError, TypeError):
        max_duration = 2.0

    filename = secure_filename(file.filename)
    # A random prefix stops concurrent uploads of the same filename from
    # overwriting each other on disk, or one request deleting another's file.
    input_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}_{filename}")

    save_start = time.time()
    file.save(input_path)
    save_duration = time.time() - save_start
    app.logger.info(f"[PROFILING] Saving uploaded file took: {save_duration:.2f} seconds")

    try:
        srt_path = transcribe_with_whisper(input_path, app.config['OUTPUT_FOLDER'], language, model_type, max_duration)
    except Exception as e:
        # Request boundary: log the traceback and discard the upload (nothing
        # else will clean it up), then report the failure to the client.
        app.logger.exception(f"Transcription error: {str(e)}")
        _remove_files(input_path)
        return f"An error occurred: {str(e)}", 500

    # Without the decorator this callback was never invoked and files piled up.
    # NOTE(review): deleting here assumes send_file has already opened the SRT
    # (POSIX allows unlinking an open file); on Windows the delete may fail.
    @after_this_request
    def remove_files(response):
        remove_start = time.time()
        _remove_files(input_path, srt_path)
        remove_duration = time.time() - remove_start
        app.logger.info(f"[PROFILING] Removing files took: {remove_duration:.2f} seconds")
        return response

    return send_file(srt_path, as_attachment=True, download_name=f"{os.path.splitext(filename)[0]}.srt")
if __name__ == '__main__':
    # Listen on all interfaces; the PORT environment variable overrides the default 7860.
    app.run(host='0.0.0.0', port=int(os.getenv('PORT', 7860)))

#############
# Alternative entry point, kept for reference.
# NOTE(review): Flask is a WSGI app and uvicorn is an ASGI server, so this would
# need an ASGI adapter (or uvicorn's WSGI interface) to work; confirm before re-enabling.
#if __name__ == "__main__":
#    import uvicorn
#    uvicorn.run(app, host="0.0.0.0", port=7860)