buchi-stdesign commited on
Commit
559724d
·
verified ·
1 Parent(s): 8afc238

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -49
app.py CHANGED
@@ -1,58 +1,90 @@
1
  import streamlit as st
2
- import subprocess
3
- import sys
4
- from pathlib import Path
5
- import base64
6
- import io
7
- import soundfile as sf
8
  import numpy as np
 
 
 
9
  from huggingface_hub import hf_hub_download
 
 
 
10
 
11
- # AIvtuber.py機能をイ
12
- sys.path.append(str(Path.cwd()))
13
- from AIvtuber import chat_session, tts_to_wav
 
 
 
 
 
 
 
 
 
 
14
 
15
- # TTSモデルの設定
16
- model_file = "https://huggingface.co/buchi-stdesign/3DAItuber-model/resolve/main/Anneli_e116_s32000.safetensors"
17
- config_file = "https://huggingface.co/buchi-stdesign/3DAItuber-model/resolve/main/config.json"
18
- style_file = "https://huggingface.co/buchi-stdesign/3DAItuber-model/resolve/main/style_vectors.npy"
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- def generate_audio(text):
21
- tts = TTSModel(
22
- model_path=model_file,
23
- config_path=config_file,
24
- style_vec_path=style_file,
25
- device="cpu" # クラウド環境ではCPUを使用
26
- )
27
- sr, wav = tts.infer(text=text, length=0.85)
28
 
29
- # 音声ータをBase64エンコード
30
- buffer = io.BytesIO()
31
- sf.write(buffer, wav, sr, format='WAV')
32
- audio_base64 = base64.b64encode(buffer.getvalue()).decode()
33
 
34
- return audio_base64
35
-
36
- st.title("AI VTuber Chat")
37
-
38
- # Vroid Hubをiframeで埋め込み
39
- st.components.v1.iframe(
40
- "https://hub.vroid.com/",
41
- height=600,
42
- scrolling=True
43
- )
44
-
45
- # チャットインターフェース
46
- user_input = st.text_input("メッセージを入力してください:")
47
-
48
- if user_input:
49
- # AIvtuber.pyの機能を使用
50
- resp = chat_session.send_message(user_input)
51
- st.write("AI:", resp.text)
52
 
53
- # 音声生成と再生
54
- sr, wav = tts_to_wav(resp.text)
55
- buffer = io.BytesIO()
56
- sf.write(buffer, wav, sr, format='WAV')
57
- audio_base64 = base64.b64encode(buffer.getvalue()).decode()
58
- st.audio(f"data:audio/wav;base64,{audio_base64}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from diffusers import StableBeluga2Pipeline
 
 
 
 
4
  import numpy as np
5
+ import soundfile as sf
6
+ import io
7
+ import base64
8
  from huggingface_hub import hf_hub_download
9
+ import asyncio
10
+ import time
11
+ from concurrent.futures import ThreadPoolExecutor
12
 
13
+ # モデルダウドと初期化
14
+ @st.cache_resource
15
+ def load_model():
16
+ model_id = "stabilityai/stable-beluga-2"
17
+ pipe = StableBeluga2Pipeline.from_pretrained(
18
+ model_id,
19
+ torch_dtype=torch.float16,
20
+ use_safetensors=True,
21
+ variant="fp16"
22
+ )
23
+ if torch.cuda.is_available():
24
+ pipe = pipe.to("cuda")
25
+ return pipe
26
 
27
+ # 音声生成を非同期で実行する関数
28
+ def generate_audio_async(pipe, text, progress_bar):
29
+ try:
30
+ # 音声生成
31
+ audio = pipe(
32
+ text,
33
+ num_inference_steps=50,
34
+ guidance_scale=7.5
35
+ ).audio[0]
36
+
37
+ # プログレスバーを更新
38
+ progress_bar.progress(1.0)
39
+
40
+ return audio
41
+ except Exception as e:
42
+ st.error(f"音声生成中にエラーが発生しました: {str(e)}")
43
+ return None
44
 
45
+ # メインアプリケーション
46
+ def main():
47
+ st.title("AI VTuber チャット")
 
 
 
 
 
48
 
49
+ # ルの読み込み
50
+ pipe = load_model()
 
 
51
 
52
+ # チャット履歴の初期化
53
+ if "messages" not in st.session_state:
54
+ st.session_state.messages = []
55
+
56
+ # チャット入力
57
+ user_input = st.chat_input("メッセージを入力してください")
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ if user_input:
60
+ # ユーザーメッセージを表示
61
+ st.session_state.messages.append({"role": "user", "content": user_input})
62
+
63
+ # 音声生成の進捗バー
64
+ progress_bar = st.progress(0.0)
65
+ status_text = st.empty()
66
+ status_text.text("音声を生成中...")
67
+
68
+ # 音声生成を非同期で実行
69
+ with ThreadPoolExecutor() as executor:
70
+ future = executor.submit(generate_audio_async, pipe, user_input, progress_bar)
71
+ audio = future.result()
72
+
73
+ if audio is not None:
74
+ # 音声データを保存
75
+ audio_path = "generated_audio.wav"
76
+ sf.write(audio_path, audio.cpu().numpy(), 44100)
77
+
78
+ # 音声プレーヤーを表示
79
+ st.audio(audio_path)
80
+
81
+ # ステータステキストを更新
82
+ status_text.text("音声生成完了!")
83
+
84
+ # チャット履歴を表示
85
+ for message in st.session_state.messages:
86
+ with st.chat_message(message["role"]):
87
+ st.write(message["content"])
88
+
89
+ if __name__ == "__main__":
90
+ main()