Spaces:
Build error
Build error
File size: 4,764 Bytes
a94d4a4 7801f60 a94d4a4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | import json
import random
import numpy as np
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
# Page metadata is configured before any other Streamlit call in this script.
st.set_page_config(page_title="MiniMind-V1")
st.title("MiniMind-V1")
# Identifier handed to from_pretrained() for both the model and the tokenizer.
model_id = "coffeecat304/minimind-v1"
@st.cache_resource
def load_model_tokenizer():
    """Load the causal LM and its tokenizer identified by ``model_id``.

    The function is decorated with ``st.cache_resource``. Both objects are
    loaded with ``trust_remote_code=True``; the tokenizer additionally uses
    ``use_fast=False``. The model is switched to eval mode before returning.

    Returns:
        A ``(model, tokenizer, None)`` tuple. The third slot is a placeholder:
        no generation config is loaded here.
    """
    remote_code = {"trust_remote_code": True}
    net = AutoModelForCausalLM.from_pretrained(model_id, **remote_code)
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=False, **remote_code)
    return net.eval(), tok, None
def clear_chat_messages():
    """Remove both conversation lists from the Streamlit session state.

    ``messages`` is deleted first, then ``chat_messages``; a missing attribute
    raises exactly as a plain ``del`` on the attribute would.
    """
    for attr_name in ("messages", "chat_messages"):
        delattr(st.session_state, attr_name)
def init_chat_messages():
    """Draw the greeting, then replay or create the stored conversation.

    A fixed assistant greeting is rendered on every call. If the session
    state already has ``messages``, each stored entry is rendered in order
    with an avatar chosen by its role. Otherwise ``messages`` and
    ``chat_messages`` are both created as empty lists.

    Returns:
        The ``messages`` list held in ``st.session_state``.
    """
    with st.chat_message("assistant", avatar='🤖'):
        st.markdown("我是MiniMind,很高兴为您服务😄 \n"
                    "注:所有AI生成内容的准确性和立场无法保证,不代表我们的态度或观点。")

    state = st.session_state

    # First run of a session: nothing to replay, just create the stores.
    if "messages" not in state:
        state.messages = []
        state.chat_messages = []
        return state.messages

    for entry in state.messages:
        icon = "🤖" if entry["role"] != "user" else "🫡"
        with st.chat_message(entry["role"], avatar=icon):
            st.markdown(entry["content"])
    return state.messages
# Sidebar controls. Each slider's current value is copied into session state
# on every script run so that main() can read it when generating a reply.
st.sidebar.title("设定调整")
st.session_state.history_chat_num = st.sidebar.slider(
    "携带历史对话条数", min_value=0, max_value=6, value=0, step=2,
)
st.session_state.max_new_tokens = st.sidebar.slider(
    "最大输入/生成长度", min_value=256, max_value=768, value=512, step=1,
)
st.session_state.top_k = st.sidebar.slider(
    "top_k", min_value=0, max_value=16, value=14, step=1,
)
st.session_state.temperature = st.sidebar.slider(
    "temperature", min_value=0.3, max_value=1.3, value=0.5, step=0.01,
)
def setup_seed(seed):
    """Seed every random number generator this app touches.

    Seeds Python's ``random``, NumPy, and PyTorch (the default generator plus
    the current and all CUDA devices), then sets the cuDNN flags to
    deterministic mode with benchmarking off.

    Args:
        seed: Value passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def main():
    """Render the chat page and, if the user submitted a prompt, stream a reply.

    Flow of one script run:
      1. Load the model/tokenizer and draw the stored conversation.
      2. If the chat input holds a prompt, show it, record it, build the
         model prompt from the most recent ``history_chat_num + 1`` entries
         of ``chat_messages``, and stream the generation into a placeholder.
      3. Record the assistant's reply in both conversation lists.
      4. Always draw the "clear conversation" button.

    If the generator yields nothing, no assistant entry is recorded, but the
    clear button is still drawn.
    """
    # The third slot returned by load_model_tokenizer() is always None.
    model, tokenizer, _ = load_model_tokenizer()
    messages = init_chat_messages()

    if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):
        # Same user avatar that init_chat_messages() uses when replaying
        # history, so a message does not change icon on the next rerun.
        with st.chat_message("user", avatar="🫡"):
            st.markdown(prompt)
        messages.append({"role": "user", "content": prompt})
        # The copy that is fed to the model is wrapped in a fixed
        # prefix/suffix; the displayed copy above is kept verbatim.
        st.session_state.chat_messages.append({"role": "user", "content": '请问,' + prompt + '?'})

        with st.chat_message("assistant", avatar='🤖'):
            placeholder = st.empty()

            # A fresh random seed for every reply.
            setup_seed(random.randint(0, 2 ** 32 - 1))

            recent_turns = st.session_state.chat_messages[-(st.session_state.history_chat_num + 1):]
            # NOTE(review): with tokenize=False the template result is
            # presumably a string, so this slice keeps its trailing
            # characters (not tokens) - confirm this is intended.
            new_prompt = tokenizer.apply_chat_template(
                recent_turns,
                tokenize=False,
                add_generation_prompt=True,
            )[-(st.session_state.max_new_tokens - 1):]

            x = tokenizer(new_prompt).data['input_ids']
            # Add a leading batch dimension of size 1.
            x = torch.tensor(x, dtype=torch.long)[None, ...]

            # Stays None when the generator produces nothing usable, which
            # previously left `answer` unbound further down.
            answer = None
            with torch.no_grad():
                res_y = model.generate(x, tokenizer.eos_token_id,
                                       max_new_tokens=st.session_state.max_new_tokens,
                                       temperature=st.session_state.temperature,
                                       top_k=st.session_state.top_k, stream=True)
                # Plain iteration ends on StopIteration only; any other error
                # raised during generation now propagates instead of being
                # silently swallowed.
                for y in res_y:
                    if y is None:
                        break
                    answer = tokenizer.decode(y[0].tolist())
                    # Do not redraw for empty text or text ending in U+FFFD
                    # (presumably a partially decoded multi-byte character).
                    if not answer or answer[-1] == '�':
                        continue
                    placeholder.markdown(answer)

            if answer is not None:
                # Strip the prompt text in case the decoded output echoes it.
                assistant_answer = answer.replace(new_prompt, "")
                messages.append({"role": "assistant", "content": assistant_answer})
                st.session_state.chat_messages.append({"role": "assistant", "content": assistant_answer})

    st.button("清空对话", on_click=clear_chat_messages)


if __name__ == "__main__":
    main()
|