import streamlit as st
import random
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from datasets import load_dataset
from opencc import OpenCC
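
# Assumed dependencies (not pinned by this script): streamlit, datasets,
# faiss-cpu, sentence-transformers, langchain-community, and an OpenCC binding
# such as opencc-python-reimplemented. HuggingFaceHub calls the Hugging Face
# Inference API, so HUGGINGFACEHUB_API_TOKEN must be set in the environment.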
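# st.session_state persists across Streamlit reruns, so the dataset download
# and conversion below only run once per browser session.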
if "data_list" not in st.session_state:
    st.session_state.data_list = []
    st.session_state.answer_list = []
|
|
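# Load the Attack on Titan wiki Q&A dataset and convert it from Traditional
# Chinese to Simplified Chinese with OpenCC's 'tw2s' profile.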
if not st.session_state.data_list:
    try:
        with st.spinner("正在读取数据库..."):
            converter = OpenCC('tw2s')
            dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
            data_list = []
            answer_list = []
            for example in dataset["train"]:
                converted_answer = converter.convert(example["Answer"])
                converted_question = converter.convert(example["Question"])
                answer_list.append(converted_answer)
                data_list.append({"Question": converted_question, "Answer": converted_answer})
            st.session_state.answer_list = answer_list
            st.session_state.data_list = data_list
        st.success("数据库读取完成!")
        print("数据库读取完成!")
    except Exception as e:
        st.error(f"读取数据集失败:{e}")
        st.stop()
|
|
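# Build a FAISS index over the answer texts once per session; the flag below
# guards against rebuilding the index on every rerun.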
if "vector_created" not in st.session_state:
    st.session_state.vector_created = False
if not st.session_state.vector_created:
    try:
        with st.spinner("正在构建向量数据库..."):
            st.session_state.embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
            st.session_state.db = FAISS.from_texts(st.session_state.answer_list, st.session_state.embeddings)
        st.success("向量数据库构建完成!")
        print("向量数据库构建完成!")
    except Exception as e:
        st.error(f"向量数据库构建失败:{e}")
        st.stop()
    st.session_state.vector_created = True
|
|
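# Track the current model configuration so the LLM is only re-initialized
# when the user changes repo_id, temperature, or max_length.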
if "repo_id" not in st.session_state:
    st.session_state.repo_id = ''
if "temperature" not in st.session_state:
    st.session_state.temperature = None
if "max_length" not in st.session_state:
    st.session_state.max_length = None
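# Retrieve the most similar answers from FAISS, assemble them into a prompt,
# and ask the selected Gemma model to answer the question.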
def answer_question(repo_id, temperature, max_length, question):
    if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
        try:
            with st.spinner("正在初始化 Gemma 模型..."):
                # Note: newer langchain-community releases deprecate HuggingFaceHub
                # in favor of HuggingFaceEndpoint; kept here as in the original.
                st.session_state.llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
            st.success("Gemma 模型初始化完成!")
            print("Gemma 模型初始化完成!")
            st.session_state.repo_id = repo_id
            st.session_state.temperature = temperature
            st.session_state.max_length = max_length
        except Exception as e:
            st.error(f"Gemma 模型加载失败:{e}")
            st.stop()
|
|
    try:
        with st.spinner("正在筛选本地数据集..."):
            # similarity_search_with_score embeds the query text itself, so pass
            # the raw question; the original code stringified the embedding
            # vector, which would then be re-embedded as text and corrupt the
            # search results.
            docs_and_scores = st.session_state.db.similarity_search_with_score(question)

            context = "\n".join([doc.page_content for doc, _ in docs_and_scores])
            print('context: ' + context)

            prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
            print('prompt: ' + prompt)

        st.success("本地数据集筛选完成!")
        print("本地数据集筛选完成!")

        with st.spinner("正在生成答案..."):
            answer = st.session_state.llm.invoke(prompt)
            # The model may echo the prompt in its completion; strip it out.
            answer = answer.replace(prompt, "").strip()
        st.success("答案已经生成!")
        print("答案已经生成!")
        return {"prompt": prompt, "answer": answer}
    except Exception as e:
        st.error(f"问答过程出错:{e}")
        return {"prompt": "", "answer": "An error occurred during the answering process."}
|
|
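# Streamlit page layout and controls.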
st.title("進擊的巨人 知识库问答系统")
|
|
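# Model selection and generation parameters; index 2 preselects
# google/recurrentgemma-2b-it.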
col1, col2 = st.columns(2)
with col1:
    gemma = st.selectbox("repo-id", ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"), 2)
with col2:
    temperature = st.number_input("temperature", value=1.0)
    max_length = st.number_input("max_length", value=1024)
|
|
st.divider()
|
|
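# Run the full QA pipeline and render both the prompt (with the retrieved
# context) and the generated answer.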
def generate_answer(repo_id, temperature, max_length, question):
    result = answer_question(repo_id, float(temperature), int(max_length), question)
    print('prompt: ' + result["prompt"])
    print('answer: ' + result["answer"])
    st.write("参考文字:")
    st.markdown(result["prompt"])
    st.write("生成答案:")
    st.write(result["answer"])
|
|
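# Two ways to ask: sample a random question from the dataset (left column)
# or type a custom question (right column).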
col3, col4 = st.columns(2)
with col3:
    if st.button("使用原数据集中的随机问题"):
        dataset_size = len(st.session_state.data_list)
        random_index = random.randint(0, dataset_size - 1)

        random_question = st.session_state.data_list[random_index]["Question"]
        origin_answer = st.session_state.data_list[random_index]["Answer"]
        print('[' + str(random_index) + '/' + str(dataset_size) + '] random_question: ' + random_question)
        print('origin_answer: ' + origin_answer)

        st.write("随机问题:")
        st.write(random_question)
        st.write("原始答案:")
        st.write(origin_answer)
        generate_answer(gemma, float(temperature), int(max_length), random_question)
|
|
with col4:
    question = st.text_area("请输入问题", "《进击的巨人》中都有哪些主要角色?")
    if st.button("提交输入的问题"):
        if not question:
            st.warning("请输入问题!")
        else:
            generate_answer(gemma, float(temperature), int(max_length), question)
|
|