File size: 6,525 Bytes
8fb0be5
4048265
ee531be
59f6126
 
df7e4e5
93ba2b1
b83f300
024f740
4b331f0
ad9e842
 
 
 
 
 
 
56a81fb
ad9e842
4b331f0
 
e8a4c9c
 
 
 
ee531be
09b358f
 
7dc42bb
8ac5c0e
 
78f79fc
2b3433c
 
 
 
 
78f79fc
 
2b3433c
 
 
69b2504
09b358f
29a10e5
c9323c5
 
791adc1
4b331f0
ee531be
4b331f0
ee531be
4b331f0
09b358f
 
 
 
 
060a1e0
09b358f
4b331f0
 
 
ee531be
4b331f0
 
09b358f
df7e4e5
 
09b358f
ee531be
4048265
09b358f
 
 
 
 
 
 
4380489
09b358f
 
 
4380489
09b358f
8ac5c0e
09b358f
ee531be
1c8f6d7
060a1e0
4048265
 
 
09b358f
623c1fa
09b358f
 
623c1fa
4b331f0
e86a1dc
2b3433c
e86a1dc
a8c8823
c9323c5
 
46f1669
2b3433c
69b2504
c9323c5
69b2504
93ba2b1
b83f300
2b3433c
b83f300
 
78f79fc
 
 
791adc1
4b331f0
2b3433c
69b2504
2b3433c
78f79fc
 
7a9eeaf
096ee7c
78f79fc
 
4b331f0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import time
import streamlit as st
import whisperx
import torch
from utils import convert_segments_object_to_text, check_password, convert_segments_object_to_text_simple, get_gpu_memory_info
from gigiachat_requests import get_access_token, get_completion_from_gigachat, get_number_of_tokens, process_transcribation_with_gigachat
from openai_requests import get_completion_from_openai, process_transcribation_with_assistant

if check_password():    
    if torch.cuda.is_available():
        print('GPU доступен')
    else:
        print('GPU не доступен')

    print(f'Версия торча: {torch.__version__}')
    print(f'Версия cuda: {torch.version.cuda}')
    print(f'Версия cudnn: {torch.backends.cudnn.version()}')
    
    st.title('Audio Transcription App')
    st.sidebar.title("Settings")
    
    device = os.getenv('DEVICE')
    batch_size = int(os.getenv('BATCH_SIZE'))
    compute_type = os.getenv('COMPUTE_TYPE')

    initial_base_prompt = os.getenv('BASE_PROMPT')
    initial_processing_prompt = os.getenv('PROCCESS_PROMPT')

    min_speakers = st.sidebar.number_input("Минимальное количество спикеров", min_value=1, value=2)
    max_speakers = st.sidebar.number_input("Максимальное количество спикеров", min_value=1, value=2)
    llm = st.sidebar.selectbox("Производитель LLM", ["Сбер", "OpenAI", "Qwen"], index=0)

    if llm == "Сбер":
        options = ["GigaChat-Plus", "GigaChat", "GigaChat-Pro"]
    elif llm == "OpenAI":
        options = ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]
    elif llm == "Qwen":
        options = ["Qwen/Qwen2-7B-Instruct"]
    else:
        options = []

    llm_model = st.sidebar.selectbox("Модель", options, index=0)
    base_prompt = st.sidebar.text_area("Промпт для резюмирования", value=initial_base_prompt)

    enable_processing = st.sidebar.checkbox("Добавить обработку транскрибации", value=False)
    processing_prompt = st.sidebar.text_area("Промпт для обработки транскрибации", value=initial_processing_prompt)

    ACCESS_TOKEN = st.secrets["HF_TOKEN"]

    uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["mp4", "wav", "m4a"])

    if uploaded_file is not None:
        file_name = uploaded_file.name

        if 'file_name' not in st.session_state or st.session_state.file_name != file_name:
            st.session_state.transcript = ''
            st.session_state.file_name = file_name

            
        st.audio(uploaded_file)
        file_extension = uploaded_file.name.split(".")[-1]  # Получаем расширение файла
        temp_file_path = f"temp_file.{file_extension}"  # Создаем временное имя файла с правильным расширением
    
        with open(temp_file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        get_gpu_memory_info()

        if 'transcript' not in st.session_state or st.session_state.transcript == '':
    
            start_time = time.time()
            with st.spinner('Транскрибируем...'):
                # Load model
                model = whisperx.load_model(os.getenv('WHISPER_MODEL_SIZE'), device, compute_type=compute_type)
                # Load and transcribe audio
                audio = whisperx.load_audio(temp_file_path)
                result = model.transcribe(audio, batch_size=batch_size, language="ru")
                print('Transcribed, now aligning')
        
                model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
                result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
                print('Aligned, now diarizing')
        
                diarize_model = whisperx.DiarizationPipeline(use_auth_token=st.secrets["HF_TOKEN"], device=device)
                diarize_segments = diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
                result_diar = whisperx.assign_word_speakers(diarize_segments, result)
        
            transcript = convert_segments_object_to_text_simple(result_diar)
            st.session_state.transcript = transcript
            end_time = time.time()  # Конец отсчета времени
            total_time = end_time - start_time
            print(f'Полный процесс транскрипции занял {total_time:.2f} секунд')
        else:
            
            transcript = st.session_state.transcript
            
        st.write("Результат транскрибации:")
        st.text(transcript)

        if (llm == 'Сбер'):
            access_token = get_access_token()
    
        if (enable_processing):
            with st.spinner('Обрабатываем транскрибацию...'):

                if (llm == 'Сбер'):
                    number_of_tokens = get_number_of_tokens(transcript, access_token, llm_model)
                    print('Количество токенов в транскрибации: ' + str(number_of_tokens))
                    transcript = process_transcribation_with_gigachat(processing_prompt, transcript, number_of_tokens + 1000, access_token, llm_model)
                    print(transcript)
                    
                elif (llm == 'OpenAI'):
                    transcript = process_transcribation_with_assistant(processing_prompt, transcript)
                    print(transcript)

                else:
                    st.write("На данный момент обработка транскрибации не поддерживается этой моделью.")
    
        with st.spinner('Резюмируем...'):
            if (llm == 'Сбер'):
                summary_answer = get_completion_from_gigachat(base_prompt + transcript, 1024, access_token, llm_model)
            elif (llm == 'OpenAI'):
                summary_answer = get_completion_from_openai(base_prompt + transcript, llm_model, 1024)
            elif (llm == 'Qwen'):
                torch.cuda.empty_cache()
                from qwen import respond
                summary_answer = respond(base_prompt + transcript)
                
            st.write("Результат резюмирования:")
            st.text(summary_answer)