buchi-stdesign commited on
Commit
ca78aa1
·
verified ·
1 Parent(s): f1a6899

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +25 -96
  2. requirements.txt +4 -12
app.py CHANGED
@@ -1,105 +1,34 @@
1
  import streamlit as st
2
- import torch
3
- from diffusers import StableBeluga2Pipeline
4
  import numpy as np
5
  import soundfile as sf
6
  import io
7
- import base64
8
- from huggingface_hub import hf_hub_download
9
- import asyncio
10
- import time
11
- from concurrent.futures import ThreadPoolExecutor
12
- import os
13
 
14
- # モデルのダウンロードと初期化
15
  @st.cache_resource
16
  def load_model():
17
- try:
18
- # モデルのキャッシュディレクトリを設定
19
- cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
20
- os.makedirs(cache_dir, exist_ok=True)
21
-
22
- model_id = "stabilityai/stable-beluga-2"
23
- pipe = StableBeluga2Pipeline.from_pretrained(
24
- model_id,
25
- torch_dtype=torch.float16,
26
- use_safetensors=True,
27
- variant="fp16",
28
- cache_dir=cache_dir
29
- )
30
- if torch.cuda.is_available():
31
- pipe = pipe.to("cuda")
32
- return pipe
33
- except Exception as e:
34
- st.error(f"モデルの読み込み中にエラーが発生しました: {str(e)}")
35
- return None
36
-
37
- # 音声生成を非同期で実行する関数
38
- def generate_audio_async(pipe, text, progress_bar):
39
- try:
40
- # 音声生成
41
- audio = pipe(
42
- text,
43
- num_inference_steps=50,
44
- guidance_scale=7.5
45
- ).audio[0]
46
-
47
- # プログレスバーを更新
48
- progress_bar.progress(1.0)
49
-
50
- return audio
51
- except Exception as e:
52
- st.error(f"音声生成中にエラーが発生しました: {str(e)}")
53
- return None
54
 
55
- # メインアプリケーション
56
- def main():
57
- st.title("AI VTuber チャット")
58
-
59
- # モデルの読み込み
60
- with st.spinner("モデルを読み込み中..."):
61
- pipe = load_model()
62
-
63
- if pipe is None:
64
- st.error("モデルの読み込みに失敗しました。アプリケーションを再起動してください。")
65
- return
66
-
67
- # チャット履歴の初期化
68
- if "messages" not in st.session_state:
69
- st.session_state.messages = []
70
-
71
- # チャット入力
72
- user_input = st.chat_input("メッセージを入力してください")
73
-
74
- if user_input:
75
- # ユーザーメッセージを表示
76
- st.session_state.messages.append({"role": "user", "content": user_input})
77
-
78
- # 音声生成の進捗バー
79
- progress_bar = st.progress(0.0)
80
- status_text = st.empty()
81
- status_text.text("音声を生成中...")
82
-
83
- # 音声生成を非同期で実行
84
- with ThreadPoolExecutor() as executor:
85
- future = executor.submit(generate_audio_async, pipe, user_input, progress_bar)
86
- audio = future.result()
87
-
88
- if audio is not None:
89
- # 音声データを保存
90
- audio_path = "generated_audio.wav"
91
- sf.write(audio_path, audio.cpu().numpy(), 44100)
92
-
93
- # 音声プレーヤーを表示
94
- st.audio(audio_path)
95
-
96
- # ステータステキストを更新
97
- status_text.text("音声生成完了!")
98
-
99
- # チャット履歴を表示
100
- for message in st.session_state.messages:
101
- with st.chat_message(message["role"]):
102
- st.write(message["content"])
103
 
104
- if __name__ == "__main__":
105
- main()
 
 
 
 
 
1
  import streamlit as st
 
 
2
  import numpy as np
3
  import soundfile as sf
4
  import io
5
+ from style_bert_vits2 import TTSModel
6
+
7
+ # モデルファイルのパス
8
+ MODEL_PATH = "Anneli_e116_s32000.safetensors"
9
+ CONFIG_PATH = "config.json"
10
+ STYLE_VEC_PATH = "style_vectors.npy"
11
 
 
12
  @st.cache_resource
13
  def load_model():
14
+ tts = TTSModel(
15
+ model_path=MODEL_PATH,
16
+ config_path=CONFIG_PATH,
17
+ style_vec_path=STYLE_VEC_PATH,
18
+ device="cpu" # 無料枠はCPUのみ
19
+ )
20
+ return tts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ def generate_audio(text, tts):
23
+ sr, wav = tts.infer(text=text, length=0.85)
24
+ buffer = io.BytesIO()
25
+ sf.write(buffer, wav, sr, format='WAV')
26
+ buffer.seek(0)
27
+ return buffer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ st.title("AI VTuber チャット(SBV2版)")
30
+ tts = load_model()
31
+ user_input = st.text_input("メッセージを入力してください:")
32
+ if user_input:
33
+ audio_fp = generate_audio(user_input, tts)
34
+ st.audio(audio_fp, format="audio/wav")
requirements.txt CHANGED
@@ -1,12 +1,4 @@
1
- torch>=2.0.0
2
- torchaudio>=2.0.0
3
- diffusers>=0.33.1
4
- transformers>=4.52.4
5
- accelerate>=1.7.0
6
- streamlit>=1.32.0
7
- numpy>=1.24.0
8
- soundfile>=0.12.1
9
- huggingface-hub>=0.27.0
10
- google-generativeai
11
- pytchat
12
- style-bert-vits2
 
1
+ streamlit
2
+ numpy
3
+ soundfile
4
+ style-bert-vits2