File size: 4,051 Bytes
64a29d6
f6beab7
 
d20cd0c
f6beab7
 
 
 
df8b5ec
 
 
 
 
 
 
 
f6beab7
 
 
 
 
 
 
 
 
d20cd0c
f6beab7
 
8f5fb37
 
f6beab7
8f5fb37
 
 
 
 
f6beab7
8f5fb37
 
 
 
 
f6beab7
8f5fb37
 
f6beab7
8f5fb37
 
f6beab7
8f5fb37
 
 
f6beab7
8f5fb37
f6beab7
8f5fb37
 
d20cd0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6beab7
8f5fb37
d20cd0c
f6beab7
8f5fb37
 
 
f6beab7
8f5fb37
 
 
 
 
 
 
 
df8b5ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import streamlit as st
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
import librosa
import srt
from datetime import timedelta

# ์˜ค๋””์˜ค ํŒŒ์ผ์„ 5์ดˆ ๊ฐ„๊ฒฉ์œผ๋กœ ๋‚˜๋ˆ„๋Š” ํ•จ์ˆ˜
def split_audio(audio, sr, segment_duration=5):
    segments = []
    for i in range(0, len(audio), int(segment_duration * sr)):
        segment = audio[i:i + int(segment_duration * sr)]
        segments.append(segment)
    return segments

# ๋ชจ๋ธ ๋ฐ ํ”„๋กœ์„ธ์„œ ๋กœ๋“œ
@st.cache_resource
def load_model():
    model = WhisperForConditionalGeneration.from_pretrained("lcjln/AIME_Project_The_Final")
    processor = WhisperProcessor.from_pretrained("lcjln/AIME_The_Final")
    return model, processor

model, processor = load_model()

# Streamlit web application interface
st.title("Whisper 자막 생성기")

# Upload one or more WAV files
uploaded_files = st.file_uploader("여기에 WAV 파일들을 드래그 앤 드롭 하세요", type=["wav"], accept_multiple_files=True)

# Show the list of uploaded files
if uploaded_files:
    st.write("업로드된 파일 목록:")
    for uploaded_file in uploaded_files:
        st.write(uploaded_file.name)

    # Run button
    if st.button("실행"):
        combined_subs = []
        last_end_time = timedelta(0)
        subtitle_index = 1

        for uploaded_file in uploaded_files:
            st.write(f"처리 중: {uploaded_file.name}")

            # Initialize the per-file progress bar
            progress_bar = st.progress(0)

            # Load the WAV file resampled to 16 kHz (Whisper's expected rate)
            st.write("오디오 파일을 처리하는 중입니다...")
            audio, sr = librosa.load(uploaded_file, sr=16000)

            progress_bar.progress(50)

            # Transcribe with the Whisper model in 5-second chunks
            st.write("모델을 통해 자막을 생성하는 중입니다...")
            segments = split_audio(audio, sr, segment_duration=5)

            for i, segment in enumerate(segments):
                inputs = processor(segment, return_tensors="pt", sampling_rate=16000)
                with torch.no_grad():
                    outputs = model.generate(inputs["input_features"], max_length=2048, return_dict_in_generate=True, output_scores=True)

                # Decode generated token ids to text
                transcription = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0].strip()

                # Rough confidence proxy: mean of the final decoding step's logits.
                # NOTE(review): this is a crude heuristic — sequence log-probs
                # would be a sounder confidence measure; kept for compatibility.
                avg_logit_score = torch.mean(outputs.scores[-1]).item()

                # BUGFIX: always advance the running timestamp by this chunk's
                # audio duration, even when the chunk is skipped below.
                # Previously a skipped chunk left last_end_time unchanged, so
                # every later subtitle drifted earlier than its audio.
                segment_duration = librosa.get_duration(y=segment, sr=sr)
                end_time = last_end_time + timedelta(seconds=segment_duration)

                # Keep only non-empty, sufficiently confident transcriptions
                if transcription and avg_logit_score > -5.0:
                    combined_subs.append(
                        srt.Subtitle(
                            index=subtitle_index,
                            start=last_end_time,
                            end=end_time,
                            content=transcription
                        )
                    )
                    subtitle_index += 1
                last_end_time = end_time

                # BUGFIX: advance the bar per chunk (it previously stalled at
                # 50% until the whole file finished).
                progress_bar.progress(50 + int(50 * (i + 1) / len(segments)))

            progress_bar.progress(100)
            st.success(f"{uploaded_file.name}의 자막이 성공적으로 생성되었습니다!")

        # Compose all collected subtitles into a single SRT document
        st.write("최종 SRT 파일을 생성하는 중입니다...")
        srt_content = srt.compose(combined_subs)

        final_srt_file_path = "combined_output.srt"
        with open(final_srt_file_path, "w", encoding="utf-8") as f:
            f.write(srt_content)

        st.success("최종 SRT 파일이 성공적으로 생성되었습니다!")

        # Download button for the final SRT file.
        # NOTE(review): "text/srt" is not a registered MIME type
        # ("application/x-subrip" is conventional) — left as-is to avoid
        # changing observable behavior.
        with open(final_srt_file_path, "rb") as srt_file:
            st.download_button(label="SRT 파일 다운로드", data=srt_file, file_name=final_srt_file_path, mime="text/srt")