yokomachi commited on
Commit
0702cff
·
verified ·
1 Parent(s): 8f2780f

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +87 -49
  3. requirements.txt +11 -6
README.md CHANGED
@@ -8,7 +8,7 @@ sdk_version: 1.43.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: This is simple "cat"bot.
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: This is a simple "cat"bot.
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,6 +1,27 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  # ページ設定
6
  st.set_page_config(
@@ -64,15 +85,19 @@ CAT_EXAMPLES = """
64
  """
65
 
66
  @st.cache_resource
67
- def load_model():
68
- """モデルをロードする関数(キャッシュ付き)"""
69
- # Hugging Faceからモデルをロード(アップロードしたモデル名に置き換えてください)
70
- model_path = "yokomachi/rinnya" # あなたのHugging Faceユーザー名に置き換えてください
71
 
72
  # トークナイザーとモデルをロード
73
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
74
  tokenizer.do_lower_case = True # rinnaモデル用の設定
75
 
 
 
 
 
76
  # モデルをロード
77
  model = AutoModelForCausalLM.from_pretrained(model_path)
78
 
@@ -80,11 +105,53 @@ def load_model():
80
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
  model.to(device)
82
 
83
- # パディングトークン設定
84
- if tokenizer.pad_token is None:
85
- tokenizer.pad_token = tokenizer.eos_token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- return tokenizer, model, device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  def extract_cat_response(generated_text):
90
  """生成されたテキストから猫の応答部分を抽出する関数"""
@@ -111,54 +178,29 @@ def post_process_response(response):
111
 
112
  return response
113
 
114
- def generate_cat_response(tokenizer, model, device, user_input):
115
- """猫の応答を生成する関数"""
116
- # プロンプトを作成
117
- prompt = f"""
118
- {CAT_PERSONALITY}
119
-
120
- 以下は猫と人間の会話例です:
121
- {CAT_EXAMPLES}
122
-
123
- 人間: {user_input}
124
- 猫:"""
125
-
126
- # 入力をトークナイズ
127
- inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
128
 
129
  # 応答を生成
130
- with torch.no_grad():
131
- outputs = model.generate(
132
- inputs,
133
- max_new_tokens=50,
134
- temperature=0.7,
135
- top_p=0.9,
136
- top_k=40,
137
- repetition_penalty=1.2,
138
- do_sample=True,
139
- pad_token_id=tokenizer.pad_token_id,
140
- eos_token_id=tokenizer.eos_token_id,
141
- no_repeat_ngram_size=3
142
- )
143
-
144
- # 生成されたテキストをデコード
145
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
146
 
147
  # 応答を抽出
148
  response = extract_cat_response(generated_text)
149
 
150
- # 応答を後処理(最小限)
151
  response = post_process_response(response)
152
 
153
  return response
154
 
155
  # アプリのタイトルと説明
156
- st.title("🐈catbot")
157
  st.markdown("""
158
  猫とじゃれあうチャットボット
159
  """)
160
 
161
-
162
  # セッション状態の初期化
163
  if "messages" not in st.session_state:
164
  st.session_state.messages = []
@@ -174,7 +216,7 @@ for message in st.session_state.messages:
174
 
175
  # モデルのロード(初回のみ実行され、その後はキャッシュから取得)
176
  try:
177
- tokenizer, model, device = load_model()
178
  model_loaded = True
179
  except Exception as e:
180
  st.error(f"モデルのロード中にエラーが発生しました: {e}")
@@ -194,13 +236,9 @@ if prompt := st.chat_input("猫に話しかけてみよう"):
194
  with st.chat_message("assistant", avatar="🐈"):
195
  with st.spinner("猫が考え中..."):
196
  try:
197
- response = generate_cat_response(tokenizer, model, device, prompt)
198
  st.markdown(response)
199
 
200
- # 猫の画像をランダムに表示(オプション)
201
- if "ニャッ" in response or "ニャー" in response:
202
- st.image("https://placekitten.com/300/200", caption="にゃー")
203
-
204
  # 応答を履歴に追加
205
  st.session_state.messages.append({"role": "assistant", "content": response})
206
  except Exception as e:
@@ -216,4 +254,4 @@ if prompt := st.chat_input("猫に話しかけてみよう"):
216
  # 会話をクリアするボタン
217
  if st.button("会話をクリア"):
218
  st.session_state.messages = []
219
- st.rerun()
 
1
  import streamlit as st
2
  import torch
3
+ import nest_asyncio
4
+ import os
5
+ from dotenv import load_dotenv
6
+ from langchain_huggingface import HuggingFacePipeline
7
+ from langchain_core.prompts import PromptTemplate
8
+ from langchain_core.runnables import RunnablePassthrough
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
10
+
11
+ # .envファイルから環境変数を読み込む
12
+ load_dotenv()
13
+
14
+ # LangSmith関連の環境変数を設定
15
+ os.environ["LANGSMITH_TRACING"] = os.getenv("LANGSMITH_TRACING")
16
+ os.environ["LANGSMITH_ENDPOINT"] = os.getenv("LANGSMITH_ENDPOINT")
17
+ os.environ["LANGSMITH_API_KEY"] = os.getenv("LANGSMITH_API_KEY")
18
+ os.environ["LANGSMITH_PROJECT"] = os.getenv("LANGSMITH_PROJECT")
19
+
20
+ # nest_asyncioを適用
21
+ nest_asyncio.apply()
22
+
23
+ # torch.classes.__path__を空のリストに設定
24
+ torch.classes.__path__ = []
25
 
26
  # ページ設定
27
  st.set_page_config(
 
85
  """
86
 
87
  @st.cache_resource
88
+ def load_langchain_model():
89
+ """LangChainモデルをロードする関数(キャッシュ付き)"""
90
+ # Hugging Faceからモデルをロード
91
+ model_path = "yokomachi/rinnya"
92
 
93
  # トークナイザーとモデルをロード
94
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
95
  tokenizer.do_lower_case = True # rinnaモデル用の設定
96
 
97
+ # パディングトークンの設定
98
+ if tokenizer.pad_token is None:
99
+ tokenizer.pad_token = tokenizer.eos_token
100
+
101
  # モデルをロード
102
  model = AutoModelForCausalLM.from_pretrained(model_path)
103
 
 
105
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
  model.to(device)
107
 
108
+ # Hugging Face pipeline作成
109
+ # Torchのエラーを回避するために設定を修正
110
+ text_generation_pipeline = pipeline(
111
+ "text-generation",
112
+ model=model,
113
+ tokenizer=tokenizer,
114
+ max_new_tokens=50,
115
+ temperature=0.7,
116
+ top_p=0.9,
117
+ top_k=40,
118
+ repetition_penalty=1.2,
119
+ do_sample=True,
120
+ pad_token_id=tokenizer.pad_token_id,
121
+ eos_token_id=tokenizer.eos_token_id,
122
+ # no_repeat_ngram_sizeパラメータを削除(問題の原因となる可能性があるため)
123
+ )
124
+
125
+ # LangChain HuggingFacePipelineの作成
126
+ llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
127
 
128
+ # プロンプトテンプレートの作成
129
+ template = """
130
+ {cat_personality}
131
+
132
+ 以下は猫と人間の会話例です:
133
+ {cat_examples}
134
+
135
+ 人間: {user_input}
136
+ 猫:"""
137
+
138
+ prompt = PromptTemplate(
139
+ input_variables=["cat_personality", "cat_examples", "user_input"],
140
+ template=template
141
+ )
142
+
143
+ # 新しいRunnableSequenceの作成
144
+ chain = (
145
+ {
146
+ "cat_personality": lambda x: CAT_PERSONALITY,
147
+ "cat_examples": lambda x: CAT_EXAMPLES,
148
+ "user_input": RunnablePassthrough()
149
+ }
150
+ | prompt
151
+ | llm
152
+ )
153
+
154
+ return chain, device
155
 
156
  def extract_cat_response(generated_text):
157
  """生成されたテキストから猫の応答部分を抽出する関数"""
 
178
 
179
  return response
180
 
181
+ def generate_cat_response_with_langchain(chain, user_input):
182
+ """LangChainを使って猫の応答を生成する関数"""
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  # 応答を生成
185
+ result = chain.invoke(user_input)
186
+
187
+ # 結果から応答テキストを取得
188
+ generated_text = result
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  # 応答を抽出
191
  response = extract_cat_response(generated_text)
192
 
193
+ # 応答を後処理
194
  response = post_process_response(response)
195
 
196
  return response
197
 
198
  # アプリのタイトルと説明
199
+ st.title("🐈 catbot")
200
  st.markdown("""
201
  猫とじゃれあうチャットボット
202
  """)
203
 
 
204
  # セッション状態の初期化
205
  if "messages" not in st.session_state:
206
  st.session_state.messages = []
 
216
 
217
  # モデルのロード(初回のみ実行され、その後はキャッシュから取得)
218
  try:
219
+ chain, device = load_langchain_model()
220
  model_loaded = True
221
  except Exception as e:
222
  st.error(f"モデルのロード中にエラーが発生しました: {e}")
 
236
  with st.chat_message("assistant", avatar="🐈"):
237
  with st.spinner("猫が考え中..."):
238
  try:
239
+ response = generate_cat_response_with_langchain(chain, prompt)
240
  st.markdown(response)
241
 
 
 
 
 
242
  # 応答を履歴に追加
243
  st.session_state.messages.append({"role": "assistant", "content": response})
244
  except Exception as e:
 
254
  # 会話をクリアするボタン
255
  if st.button("会話をクリア"):
256
  st.session_state.messages = []
257
+ st.rerun()
requirements.txt CHANGED
@@ -1,7 +1,12 @@
 
 
 
 
1
  streamlit>=1.28.0
2
- torch>=2.0.0
3
- transformers>=4.30.0
4
- huggingface-hub>=0.16.0
5
- protobuf>=3.20.0
6
- accelerate>=0.20.0
7
- sentencepiece>=0.1.99
 
 
1
+ huggingface-hub>=0.19.4
2
+ torch>=2.0.1
3
+ transformers>=4.30.2
4
+ sentencepiece>=0.1.99
5
  streamlit>=1.28.0
6
+ protobuf>=3.20.3
7
+ accelerate>=0.20.3
8
+ langchain>=0.1.0
9
+ langchain-community>=0.0.10
10
+ langchain-huggingface>=0.0.2
11
+ python-dotenv>=1.0.0
12
+ nest-asyncio>=1.5.6