Commit 1559fe0 (parent: 728a621)
Update transformers to 4.39.3 and optimize model loading

Files changed:
- app.py (+9 -7)
- requirements.txt (+1 -1)
app.py CHANGED

@@ -36,6 +36,7 @@ model_name_or_path = "xiaoxishui/internlm2_5-7b-chat"
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 
+
 @dataclass
 class GenerationConfig:
     # this config is used for chat to provide more diversity
@@ -187,7 +188,10 @@ def on_btn_click():
 def load_model():
     model = (AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
-        trust_remote_code=True
+        trust_remote_code=True,
+        use_cache=False,  # disable the KV cache
+        torch_dtype=torch.bfloat16,
+        device_map="auto")).cuda()
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                               trust_remote_code=True)
     return model, tokenizer
@@ -210,17 +214,16 @@ def prepare_generation_config():
     return generation_config
 
 
-user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
-robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
-cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
-    <|im_start|>assistant\n'
+user_prompt = '👥\n{user}\n'
+robot_prompt = '🤖\n{robot}\n'
+cur_query_prompt = '👥\n{user}\n'
 
 
 def combine_history(prompt):
     messages = st.session_state.messages
     meta_instruction = ('You are a helpful, honest, '
                         'and harmless AI assistant.')
-    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
+    total_prompt = f'<s>🤖\n{meta_instruction}\n'
     for message in messages:
         cur_content = message['content']
         if message['role'] == 'user':
@@ -293,4 +296,3 @@ def main():
 
 if __name__ == '__main__':
     main()
-
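The substantive change is in load_model(): the model now loads in bfloat16 with accelerate's device_map="auto" and with the KV cache disabled. Below is a minimal sketch of the resulting loader, assuming the imports and the model_name_or_path defined earlier in app.py; it deliberately drops the trailing .cuda() from the committed code, since device_map="auto" already places the weights and calling .cuda() on a dispatched model can raise an error when any layers end up offloaded to CPU.

    # Sketch of the loader after this commit; not a verbatim copy of app.py.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name_or_path = "xiaoxishui/internlm2_5-7b-chat"

    def load_model():
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,      # InternLM ships custom modeling code
            use_cache=False,             # no KV cache: less memory, slower decoding
            torch_dtype=torch.bfloat16,  # bf16 weights, half the memory of fp32
            device_map="auto",           # let accelerate place weights on GPU/CPU
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                                  trust_remote_code=True)
        return model, tokenizer

Note the trade-off: with use_cache=False the past key/values are recomputed at every generation step, so the memory saving comes at a real cost in decoding speed.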
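For context, the rewritten templates feed combine_history(), which flattens the Streamlit chat history into a single prompt string. A self-contained sketch follows, under two assumptions: the real function reads st.session_state.messages (here passed as a parameter so the snippet runs standalone), and the .format(...) calls and the 'robot' role name follow the standard InternLM web demo this app is derived from.

    # Standalone sketch of prompt assembly with the new emoji-delimited templates.
    user_prompt = '👥\n{user}\n'
    robot_prompt = '🤖\n{robot}\n'
    cur_query_prompt = '👥\n{user}\n'

    def combine_history(prompt, messages):
        meta_instruction = ('You are a helpful, honest, '
                            'and harmless AI assistant.')
        total_prompt = f'<s>🤖\n{meta_instruction}\n'  # system preamble
        for message in messages:
            cur_content = message['content']
            if message['role'] == 'user':
                total_prompt += user_prompt.format(user=cur_content)
            elif message['role'] == 'robot':
                total_prompt += robot_prompt.format(robot=cur_content)
        total_prompt += cur_query_prompt.format(user=prompt)  # current turn
        return total_prompt

    # Example: one prior exchange plus the current query.
    history = [{'role': 'user', 'content': 'Hi'},
               {'role': 'robot', 'content': 'Hello! How can I help?'}]
    print(combine_history('What is InternLM2.5?', history))

One thing worth flagging: the new cur_query_prompt is identical to user_prompt, so unlike the removed template it no longer ends with an assistant marker signaling the model that it is its turn to answer.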
requirements.txt CHANGED

@@ -1,5 +1,5 @@
 streamlit>=1.8.0
-transformers==4.
+transformers==4.39.3
 torch>=2.0.0
 accelerate>=0.20.0
 sentencepiece
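Pinning transformers==4.39.3 makes the Space build reproducible. A quick hypothetical check (not part of the commit) to confirm which versions actually resolved in the environment:

    # Hypothetical snippet: print the versions the environment resolved.
    from importlib.metadata import version

    for pkg in ('streamlit', 'transformers', 'torch', 'accelerate', 'sentencepiece'):
        print(pkg, version(pkg))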