Diezu commited on
Commit
76d56a2
·
verified ·
1 Parent(s): b6c7749

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
4
+ from peft import PeftModel
5
+ import numpy as np
6
+ import pyaudio
7
+
8
+ # Tải mô hình
9
+ @st.cache_resource
10
+ def load_model():
11
+ base_model_id = "openai/whisper-tiny"
12
+ adapter_id = "longhoang2112/whisper-turbo-fine-tuning-adapters"
13
+
14
+ processor = WhisperProcessor.from_pretrained(base_model_id)
15
+ model = WhisperForConditionalGeneration.from_pretrained(base_model_id)
16
+ try:
17
+ model = PeftModel.from_pretrained(model, adapter_id)
18
+ model.set_active_adapters(adapter_id)
19
+ except:
20
+ st.warning("Adapter loading failed. Using base model.")
21
+
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ model.to(device)
24
+ return processor, model, device
25
+
26
+ processor, model, device = load_model()
27
+
28
+ # Ghi âm
29
+ def record_audio(duration=5, sample_rate=16000):
30
+ CHUNK = 1024
31
+ FORMAT = pyaudio.paFloat32
32
+ CHANNELS = 1
33
+ p = pyaudio.PyAudio()
34
+ stream = p.open(format=FORMAT, channels=CHANNELS, rate=sample_rate, input=True, frames_per_buffer=CHUNK)
35
+
36
+ st.write(f"Đang ghi âm... ({duration} giây)")
37
+ frames = []
38
+ for _ in range(0, int(sample_rate / CHUNK * duration)):
39
+ data = stream.read(CHUNK)
40
+ frames.append(np.frombuffer(data, dtype=np.float32))
41
+
42
+ stream.stop_stream()
43
+ stream.close()
44
+ p.terminate()
45
+ return np.concatenate(frames), sample_rate
46
+
47
+ # Giao diện
48
+ st.title("Whisper Turbo với Adapter")
49
+ duration = st.slider("Thời gian ghi âm (giây):", 1, 10, 5)
50
+
51
+ if st.button("Ghi âm"):
52
+ audio, sample_rate = record_audio(duration)
53
+ input_features = processor(audio, sampling_rate=sample_rate, return_tensors="pt").input_features.to(device)
54
+ with torch.no_grad():
55
+ predicted_ids = model.generate(input_features)
56
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
57
+ st.write("**Kết quả:**", transcription)