File size: 7,003 Bytes
18b88bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
import time
from faster_whisper import WhisperModel
import logging
from flask import Flask, render_template, request, send_file, after_this_request
from werkzeug.utils import secure_filename

# Flask application object; its logger is raised to INFO so the profiling
# messages emitted by the handlers below are visible.
app = Flask(__name__)
app.logger.setLevel(logging.INFO)

# Static configuration: working directories and accepted upload extensions.
app.config.update(
    UPLOAD_FOLDER='uploads',
    OUTPUT_FOLDER='outputs',
    ALLOWED_EXTENSIONS={'mp3', 'wav', 'flac', 'mp4', 'mkv', 'mov', 'm4a', 'ogg', 'webm'},
)

# Create both working directories up front (no error if they already exist).
for _working_dir in (app.config['UPLOAD_FOLDER'], app.config['OUTPUT_FOLDER']):
    os.makedirs(_working_dir, exist_ok=True)
del _working_dir

# Loaded models keyed by model type, so each model is only loaded once.
model_cache = dict()

def get_model(model_type):
    """Return the cached model for ``model_type``, loading it on first use.

    The model directory is looked up at ``/app/models/<model_type>`` and, when
    that path does not exist, at ``<cwd>/models/<model_type>`` (local
    development fallback). Loaded models are kept in ``model_cache``.

    Args:
        model_type: Name of the model sub-directory. It must be a single path
            component because it is interpolated into a filesystem path and
            the request handler passes it straight from form data.

    Returns:
        The ``WhisperModel`` loaded for ``model_type`` (device ``cpu``,
        compute type ``int8``).

    Raises:
        ValueError: If ``model_type`` is not a string, is empty, is ``.`` or
            ``..``, or contains a path separator or NUL character.
    """
    # Validate before touching the cache or the filesystem so that values such
    # as "../../somewhere" can never escape the models directory.
    if (
        not isinstance(model_type, str)
        or model_type in ("", ".", "..")
        or any(ch in model_type for ch in ("/", "\\", "\0"))
    ):
        raise ValueError(f"Invalid model type: {model_type!r}")

    if model_type not in model_cache:
        model_path = f"/app/models/{model_type}"
        # Fallback for local development if /app/models doesn't exist
        if not os.path.exists(model_path):
            model_path = os.path.join(os.getcwd(), "models", model_type)

        app.logger.info(f"Loading model: {model_type} from {model_path}")
        model_cache[model_type] = WhisperModel(model_path, device="cpu", compute_type="int8")
    return model_cache[model_type]

def allowed_file(filename):
    """Tell whether ``filename`` ends in one of the configured extensions.

    The extension is the text after the last dot and is compared
    case-insensitively against ``app.config['ALLOWED_EXTENSIONS']``. A name
    without any dot is never allowed.
    """
    _stem, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']

def format_srt_time(seconds):
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond once, and every field is
    derived from that integer. This avoids float truncation (for example
    ``1.001`` rendering as ``,000``) and lets a rounded-up millisecond carry
    into the seconds, minutes and hours fields.

    Args:
        seconds: Offset in seconds (int or float). Negative values are
            clamped to zero because SRT has no negative timestamps.

    Returns:
        The timestamp string, e.g. ``"01:01:01,500"`` for ``3661.5``.
    """
    total_millis = max(0, int(round(seconds * 1000)))
    hours, remainder = divmod(total_millis, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"

def _chunk_from_words(words):
    """Build one caption dict (start, end, text) from a non-empty run of words."""
    return {
        'start': words[0].start,
        'end': words[-1].end,
        # Words are concatenated as emitted rather than re-joined with spaces:
        # faster-whisper word strings are expected to carry their own leading
        # whitespace, and forcing a space between words breaks languages that
        # are written without spaces.
        'text': "".join(w.word for w in words).strip(),
    }


def _split_words_into_chunks(words, max_duration):
    """Group timestamped words into captions spanning at most ``max_duration`` seconds."""
    chunks = []
    current = []
    for word in words:
        # Close the running chunk when this word would stretch it past the
        # limit. A single word longer than the limit still gets its own chunk.
        if current and (word.end - current[0].start > max_duration):
            chunks.append(_chunk_from_words(current))
            current = []
        current.append(word)
    if current:
        chunks.append(_chunk_from_words(current))
    return chunks


def _write_srt(captions, srt_file):
    """Write caption dicts to ``srt_file`` as numbered SRT cues (UTF-8)."""
    with open(srt_file, "w", encoding="utf-8") as f:
        for idx, caption in enumerate(captions, start=1):
            start_time_srt = format_srt_time(caption['start'])
            end_time_srt = format_srt_time(caption['end'])
            f.write(f"{idx}\n{start_time_srt} --> {end_time_srt}\n{caption['text']}\n\n")


def transcribe_with_whisper(input_file, output_dir, language, model_type, max_duration):
    """Transcribe ``input_file`` and write the result as an SRT file.

    Segments longer than ``max_duration`` seconds that come with word
    timestamps are split into shorter captions at word boundaries; all other
    segments are kept whole.

    Args:
        input_file: Path of the media file to transcribe.
        output_dir: Directory in which the SRT file is created (created if
            missing).
        language: Language code passed to the model. An empty value is passed
            as ``None`` so the model detects the language itself.
        model_type: Model name resolved through ``get_model``.
        max_duration: Maximum caption length in seconds used when splitting.

    Returns:
        Path of the SRT file that was written. The file name is unique per
        call so concurrent transcriptions cannot overwrite or delete each
        other's output.
    """
    import uuid  # local import keeps this block independent of the file header

    model = get_model(model_type)

    # Perform transcription
    transcribe_start = time.time()
    # faster-whisper returns a generator of segments and info; the actual work
    # happens while the generator is consumed in the loop below.
    segments, info = model.transcribe(
        input_file,
        language=language or None,
        word_timestamps=True,
    )

    # Process segments into short chunks
    processed_segments = []
    for segment in segments:
        # If segment is already short enough or has no word timestamps, keep it as is
        if (segment.end - segment.start <= max_duration) or not segment.words:
            processed_segments.append({
                'start': segment.start,
                'end': segment.end,
                'text': segment.text.strip(),
            })
        else:
            processed_segments.extend(_split_words_into_chunks(segment.words, max_duration))

    transcribe_duration = time.time() - transcribe_start
    app.logger.info(f"[PROFILING] Transcribing file with {model_type} model took: {transcribe_duration:.2f} seconds")
    app.logger.info(f"[PROFILING] Detected language: {info.language} with probability {info.language_probability:.2f}")

    # Save to an SRT file with a per-call unique name
    os.makedirs(output_dir, exist_ok=True)
    srt_file = os.path.join(output_dir, f"output_{uuid.uuid4().hex}.srt")

    srt_save_start = time.time()
    _write_srt(processed_segments, srt_file)
    srt_save_duration = time.time() - srt_save_start
    app.logger.info(f"[PROFILING] Saving to SRT file took: {srt_save_duration:.2f} seconds")

    return srt_file

@app.route('/', methods=['GET'])
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    landing_page = render_template('index.html')
    return landing_page

@app.route('/transcribe', methods=['POST'])
def transcribe():
    """Transcribe an uploaded media file and return the result as an SRT download.

    Form fields:
        file: The uploaded media file (required; extension must be allowed).
        language: Language code, default ``'en'``.
        model_type: Model name, default ``'accurate'``.
        max_duration: Maximum caption length in seconds; values that are not
            numbers or fall outside 1..5 are replaced by 2.0.

    Returns:
        The SRT file as an attachment on success, a 400 response for a missing
        or disallowed upload, or a 500 response if saving or transcription
        fails. Temporary files are removed after the request in every case.
    """
    import uuid  # local import keeps this block independent of the file header

    if 'file' not in request.files:
        return 'No file uploaded', 400

    file = request.files['file']
    if file.filename == '':
        return 'No selected file', 400

    if not allowed_file(file.filename):
        # Sorted so the message is stable between runs (set order is arbitrary).
        allowed = ', '.join(sorted(app.config['ALLOWED_EXTENSIONS']))
        return 'Invalid file type. Allowed types: ' + allowed, 400

    language = request.form.get('language', 'en')
    model_type = request.form.get('model_type', 'accurate')
    try:
        max_duration = float(request.form.get('max_duration', 2.0))
    except (ValueError, TypeError):
        max_duration = 2.0
    else:
        if not (1 <= max_duration <= 5):
            max_duration = 2.0

    # secure_filename can return an empty string; fall back to a fixed name so
    # the upload path never degenerates into the upload directory itself.
    filename = secure_filename(file.filename) or 'upload'
    # A unique prefix keeps simultaneous uploads with the same name apart.
    input_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}_{filename}")
    temp_paths = [input_path]

    # Registered before any file is written so that cleanup also runs when
    # saving or transcription fails.
    @after_this_request
    def remove_files(response):
        remove_start = time.time()
        for path in temp_paths:
            # Each file is removed independently so one failure does not
            # leave the other behind.
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # never created, e.g. the save itself failed
            except OSError as e:
                app.logger.error(f"Error removing files: {e}")
        remove_duration = time.time() - remove_start
        app.logger.info(f"[PROFILING] Removing files took: {remove_duration:.2f} seconds")
        return response

    try:
        # Save uploaded file
        save_start = time.time()
        file.save(input_path)
        save_duration = time.time() - save_start
        app.logger.info(f"[PROFILING] Saving uploaded file took: {save_duration:.2f} seconds")

        srt_path = transcribe_with_whisper(input_path, app.config['OUTPUT_FOLDER'], language, model_type, max_duration)
        temp_paths.append(srt_path)

        download_name = f"{os.path.splitext(filename)[0]}.srt"
        return send_file(srt_path, as_attachment=True, download_name=download_name)
    except Exception as e:
        # Top-level request boundary: log with traceback and report the failure.
        app.logger.exception(f"Transcription error: {str(e)}")
        return f"An error occurred: {str(e)}", 500

if __name__ == '__main__':
    # Script entry point: listen on all interfaces, taking the port from the
    # PORT environment variable and falling back to 7860.
    port = int(os.getenv('PORT', 7860))
    app.run(port=port, host='0.0.0.0')


#############

#if __name__ == "__main__":
#    import uvicorn
#    uvicorn.run(app, host="0.0.0.0", port=7860)