ver 0.7 for test
Files changed:
- app.py +869 -0
- requirements.txt +18 -0
app.py
ADDED
@@ -0,0 +1,869 @@
# --------------------------------------
# Libraries
# --------------------------------------
import os
import time
import gc        # memory cleanup
import re        # regex-based text cleanup
import requests  # HTTP client (fix: used by the DeepL helper below but was not imported)

# HuggingFace
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# OpenAI
import openai
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI

# LangChain
from langchain.llms import HuggingFacePipeline
from transformers import pipeline

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import VectorDBQA
from langchain.vectorstores import Chroma

from langchain import PromptTemplate, ConversationChain
from langchain.chains.question_answering import load_qa_chain  # QA chat
from langchain.document_loaders import SeleniumURLLoader       # URL fetching
from langchain.docstore.document import Document               # wrap text as Documents
# from langchain.memory import ConversationBufferWindowMemory  # chat history
from langchain.memory import ConversationSummaryBufferMemory   # chat history

from typing import Any
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Gradio
import gradio as gr

# PyPdf
from pypdf import PdfReader

# test
import langchain  # (imported so that langchain.debug = True can be set)


# --------------------------------------
# Class that stores per-user session state
# (Reference) https://blog.shikoan.com/gradio-state/
# --------------------------------------
class SessionState:
    def __init__(self):
        # Hugging Face
        self.tokenizer = None
        self.pipe = None
        self.model = None

        # LangChain
        self.llm = None
        self.embeddings = None
        self.current_model = ""
        self.current_embedding = ""
        self.db = None                  # Vector DB
        self.memory = None              # LangChain chat memory
        self.qa_chain = None            # load_qa_chain
        self.conversation_chain = None  # ConversationChain
        self.embedded_urls = []

        # Apps
        self.dialogue = []  # recent chat history for display

    # --------------------------------------
    # Empty Cache
    # --------------------------------------
    def cache_clear(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # GPU memory clear

        gc.collect()  # CPU memory clear

    # --------------------------------------
    # Clear models (llm: LLM, embd: embeddings, db: vector DB)
    # --------------------------------------
    def clear_memory(self, llm=False, embd=False, db=False):
        # DB
        if db and self.db:
            self.db.delete_collection()
            self.db = None
            self.embedded_urls = []

        # Embeddings model
        if llm or embd:
            self.embeddings = None
            self.current_embedding = ""
            self.qa_chain = None

        # LLM model
        if llm:
            self.llm = None
            self.pipe = None
            self.model = None
            self.current_model = ""
            self.tokenizer = None
            self.memory = None
            self.chat_history = []  # TODO: verify whether this attribute is still needed

        self.cache_clear()

    # --------------------------------------
    # Load chat history as a list
    # --------------------------------------
    def load_chat_history(self) -> list:
        chat_history = []
        if self.memory is None:  # fix: guard against an uninitialized memory
            return chat_history
        try:
            chat_memory = self.memory.load_memory_variables({})['chat_history']
        except KeyError:
            return chat_history

        # Read the chat history in (user, AI) pairs
        for i in range(0, len(chat_memory), 2):
            user_message = chat_memory[i].content
            ai_message = ""
            if i + 1 < len(chat_memory):
                ai_message = chat_memory[i + 1].content
            chat_history.append([user_message, ai_message])
        return chat_history
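
# A sketch of the list shape load_chat_history() returns (hypothetical turns):
#   [["こんにちは", "こんにちは!ご用件をお聞かせください。"],
#    ["営業時間は?", "9時から18時です。"]]
# i.e. one [user_message, ai_message] pair per turn, which is the format
# gr.Chatbot expects.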

# --------------------------------------
# Custom TextSplitter (splits text to fit within the LLM's token budget)
# (Reference) https://www.sato-susumu.com/entry/2023/04/30/131338
# → Added separators such as "!", "?", ")", ".", "!", "?", and ","
# --------------------------------------
class JPTextSplitter(RecursiveCharacterTextSplitter):
    def __init__(self, **kwargs: Any):
        separators = ["\n\n", "\n", "。", "!", "?", ")", "、", ".", "!", "?", ",", " ", ""]
        super().__init__(separators=separators, **kwargs)

# Chunking parameters
chunk_size = 512
chunk_overlap = 35

text_splitter = JPTextSplitter(
    chunk_size = chunk_size,        # maximum characters per chunk
    chunk_overlap = chunk_overlap,  # maximum overlapping characters
)
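
# Minimal usage sketch (hypothetical input): the splitter tries the separators
# in order, preferring paragraph breaks, then sentence-final punctuation, so
# Japanese sentences stay intact where possible.
#   docs = text_splitter.create_documents(["深層学習は機械学習の一分野です。" * 200])
#   print(len(docs), max(len(d.page_content) for d in docs))  # chunks of <= 512 chars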

# --------------------------------------
# Translate memory with DeepL to reduce token usage (when using OpenAI models)
# --------------------------------------
DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
DEEPL_API_KEY = "YOUR_DEEPL_API_KEY"

def deepl_memory(ss: SessionState) -> SessionState:
    if ss.current_model == "gpt-3.5-turbo":
        # Fetch the latest turn from memory
        # (fix: the last two entries are the user and AI messages; the original indexed a single message)
        user_message = ss.memory.chat_memory.messages[-2].content
        ai_message = ss.memory.chat_memory.messages[-1].content
        text = [user_message, ai_message]

        # DeepL settings
        params = {
            "auth_key": DEEPL_API_KEY,
            "text": text,
            "target_lang": "EN",
            "source_lang": "JA"
        }
        request = requests.post(DEEPL_API_ENDPOINT, data=params)
        request.raise_for_status()  # raise an exception if the response status code indicates an error
        response = request.json()

        # Extract the translations from the JSON response
        user_message = response["translations"][0]["text"]
        ai_message = response["translations"][1]["text"]

        # Drop the last turn (both messages) from memory and re-add it translated
        ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]
        ss.memory.chat_memory.add_user_message(user_message)
        ss.memory.chat_memory.add_ai_message(ai_message)

    return ss
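
# For reference, a successful DeepL v2 /translate response for the two texts
# sent above has this shape (values hypothetical):
#   {"translations": [{"detected_source_language": "JA", "text": "Hello."},
#                     {"detected_source_language": "JA", "text": "Hi, how can I help?"}]}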

# --------------------------------------
# LangChain custom prompts
# llama tokenizer
# https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/

# OpenAI tokenizer
# https://platform.openai.com/tokenizer
# --------------------------------------

# --------------------------------------
# Conversation Chain Template
# --------------------------------------

# Tokens: OpenAI 104 / Llama 105 <- in Japanese: OpenAI 191 / Llama 162
sys_chat_message = """
The following is a conversation between an AI concierge and a customer.
The AI understands what the customer wants to know from the conversation history and the latest question,
and gives many specific details in Japanese. If the AI does not know the answer to a question, it does not
make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

chat_common_format = """
===
Question: {query}
===
Conversation History:
{chat_history}
===
日本語の回答:"""

chat_template_std = f"{sys_chat_message}{chat_common_format}"
chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common_format}[/INST]"

# --------------------------------------
# QA Chain Template
# --------------------------------------
# Tokens: OpenAI 113 / Llama 111 <- in Japanese: OpenAI 256 / Llama 225
sys_qa_message = """
You are an AI concierge who carefully answers questions from customers based on references.
You understand what the customer wants to know from the "Conversation History" and "Question",
and give a specific answer in Japanese using sentences extracted from the following references.
If you do not know the answer, do not make up an answer and reply,
"誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

qa_common_format = """
===
Question:
{query}
===
References:
{context}
===
Conversation History:
{chat_history}
===
日本語の回答:"""

qa_template_std = f"{sys_qa_message}{qa_common_format}"
qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"

# --------------------------------------
# Summarization prompt for ConversationSummaryBufferMemory
# Source → https://github.com/langchain-ai/langchain/blob/894c272a562471aadc1eb48e4a2992923533dea0/langchain/memory/prompt.py#L26-L49
# --------------------------------------
# Tokens: OpenAI 212 / Llama 214 <- in Japanese: OpenAI 397 / Llama 297
conversation_summary_template = """
Using the example as a guide, compose a summary in English that gives an overview of the conversation by summarizing the "current summary" and the "new conversation".
===
Example
[Current Summary] Customer asks AI what it thinks about Artificial Intelligence, AI says Artificial Intelligence is a good tool.

[New Conversation]
Human: なぜ人工知能が良いツールだと思いますか?
AI: 人工知能は「人間の可能性を最大限に引き出すことを助ける」からです。

[New Summary] Customer asks what you think about Artificial Intelligence, and AI responds that it is a good force that helps humans reach their full potential.
===
[Current Summary] {summary}

[New Conversation]
{new_lines}

[New Summary]
""".strip()
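
# Sketch of how these templates are consumed downstream (hypothetical values):
#   prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template_std)
#   print(prompt.format(query="営業時間は?", chat_history="Human: こんにちは\nAI: こんにちは!"))
# The *_llama2 variants wrap the same text in the <s>[INST] <<SYS>> ... <</SYS>> ... [/INST]
# chat format that Llama-2 instruction models were tuned on.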

# Load models
def load_models(
    ss: SessionState,
    model_id: str,
    embedding_id: str,
    openai_api_key: str,
    load_in_8bit: bool,
    verbose: bool,
    temperature: float,
    min_length: int,
    max_new_tokens: int,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    num_return_sequences: int,
) -> (SessionState, str):

    # --------------------------------------
    # Check the OpenAI API key
    # --------------------------------------
    if (model_id == "gpt-3.5-turbo" or embedding_id == "text-embedding-ada-002"):
        # Preflight check (fix: use .get() to avoid a KeyError when the variable is unset)
        if not os.environ.get("OPENAI_API_KEY"):
            status_message = "❌ OpenAI API KEY を設定してください"
            return ss, status_message

    # --------------------------------------
    # LLM setup
    # --------------------------------------
    # OpenAI model
    if model_id == "gpt-3.5-turbo":
        ss.clear_memory(llm=True, db=True)
        ss.llm = ChatOpenAI(
            model_name = model_id,
            temperature = temperature,
            verbose = verbose,
            max_tokens = max_new_tokens,
        )

    # Hugging Face GPT model
    else:
        ss.clear_memory(llm=True, db=True)

        if model_id == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
        else:
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id)

        ss.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            load_in_8bit = load_in_8bit,
            torch_dtype = torch.float16,
            device_map = "auto",
        )

        ss.pipe = pipeline(
            "text-generation",
            model = ss.model,
            tokenizer = ss.tokenizer,
            min_length = min_length,
            max_new_tokens = max_new_tokens,
            do_sample = True,
            top_k = top_k,
            top_p = top_p,
            repetition_penalty = repetition_penalty,
            num_return_sequences = num_return_sequences,
            temperature = temperature,
        )
        ss.llm = HuggingFacePipeline(pipeline=ss.pipe)

    # --------------------------------------
    # Embedding model setup
    # --------------------------------------
    if ss.current_embedding == embedding_id:
        # Embeddings unchanged: record the (possibly new) LLM and report status
        # (fix: the original returned None here instead of the (ss, status) pair Gradio expects)
        ss.current_model = model_id
        status_message = "✅ LLM: " + ss.current_model + ", embeddings: " + ss.current_embedding
        return ss, status_message

    # Reset embeddings and vectordb
    ss.clear_memory(embd=True, db=True)

    if embedding_id == "None":
        pass

    # OpenAI
    elif embedding_id == "text-embedding-ada-002":
        ss.embeddings = OpenAIEmbeddings()

    # Hugging Face
    else:
        ss.embeddings = HuggingFaceEmbeddings(model_name=embedding_id)

    # --------------------------------------
    # Save the current model names on the SessionState object
    # --------------------------------------
    ss.current_model = model_id
    ss.current_embedding = embedding_id

    # Status message
    status_message = "✅ LLM: " + ss.current_model + ", embeddings: " + ss.current_embedding

    return ss, status_message
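
# Note (assumption about the runtime): load_in_8bit=True relies on the
# bitsandbytes package and a CUDA GPU. On a CPU-only host, pass
# load_in_8bit=False and drop torch_dtype=torch.float16 to stay in float32.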

def conversation_prep(ss: SessionState) -> SessionState:
    if ss.conversation_chain is None:

        human_prefix = "Human: "
        ai_prefix = "AI: "
        chat_template = chat_template_std

        if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            # Rinna-specific settings: newline token and memory prefixes (see the official model page)
            chat_template = chat_template.replace("\n", "<NL>")
            human_prefix = "ユーザー: "
            ai_prefix = "システム: "

        elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"):
            chat_template = chat_template_llama2

        chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)

        if ss.memory is None:
            conversation_summary_prompt = PromptTemplate(input_variables=['summary', 'new_lines'], template=conversation_summary_template)
            ss.memory = ConversationSummaryBufferMemory(
                llm = ss.llm,
                memory_key = "chat_history",
                input_key = "query",
                output_key = "output_text",
                return_messages = True,
                human_prefix = human_prefix,
                ai_prefix = ai_prefix,
                max_token_limit = 512,
                prompt = conversation_summary_prompt,
            )

        ss.conversation_chain = ConversationChain(
            llm = ss.llm,
            prompt = chat_prompt,
            memory = ss.memory,
            input_key = "query",         # fix: match the prompt's 'query' variable (the default is 'input')
            output_key = "output_text",  # fix: match the memory's output_key so save_context finds the reply
        )

    return ss

def initialize_db(ss: SessionState) -> SessionState:

    # client = chromadb.PersistentClient(path="./db")
    ss.db = Chroma(
        collection_name = "user_reference",
        embedding_function = ss.embeddings,
        # client = client
    )

    return ss

def embedding_process(ss: SessionState, ref_documents: list) -> SessionState:

    # --------------------------------------
    # Restructure the text and remove unwanted strings
    # --------------------------------------
    for i in range(len(ref_documents)):
        content = ref_documents[i].page_content.strip()

        # --------------------------------------
        # For PDFs, apply heavier cleanup to work around extraction errors
        # --------------------------------------
        if ".pdf" in ref_documents[i].metadata['source']:
            # Protect line breaks that follow sentence-final or closing punctuation so
            # they survive the whitespace stripping below.
            # (fix: several of the original replacement strings injected the wrong
            # punctuation, apparently copy-paste slips; each break now keeps its own
            # closing character, and dead duplicate entries were dropped.)
            pdf_replacement_sets = [
                ('\n ', '**PLACEHOLDER+SPACE**'),
                ('\n\u3000', '**PLACEHOLDER+SPACE**'),
                ('.\n', '。**PLACEHOLDER**'),
                (',\n', '。**PLACEHOLDER**'),
                ('。\n', '。**PLACEHOLDER**'),
                ('?\n', '?**PLACEHOLDER**'),
                ('?\n', '?**PLACEHOLDER**'),
                ('!\n', '!**PLACEHOLDER**'),
                ('!\n', '!**PLACEHOLDER**'),
                (')\n', ')**PLACEHOLDER**'),
                (')\n', ')**PLACEHOLDER**'),
                (']\n', ']**PLACEHOLDER**'),
                ('】\n', '】**PLACEHOLDER**'),
            ]
            for original, replacement in pdf_replacement_sets:
                content = content.replace(original, replacement)
            content = content.replace(" ", "")
        # --------------------------------------

        # Remove unwanted strings and whitespace
        remove_texts = ["\n", "\r", " "]
        for remove_text in remove_texts:
            content = content.replace(remove_text, "")

        # Convert tabs and ideographic spaces to single spaces
        replace_texts = ["\t", "\u3000"]
        for replace_text in replace_texts:
            content = content.replace(replace_text, " ")

        # Restore the legitimate PDF line breaks
        if ".pdf" in ref_documents[i].metadata['source']:
            content = content.replace('**PLACEHOLDER**', '\n').replace('**PLACEHOLDER+SPACE**', '\n ')

        ref_documents[i].page_content = content

    # --------------------------------------
    # Split into chunks
    texts = text_splitter.split_documents(ref_documents)

    # --------------------------------------
    # Add the prefix the multilingual-e5 model was trained with
    # (note the trailing space, matching the "query: " prefix used at search time)
    # https://hironsan.hatenablog.com/entry/2023/07/05/073150
    # --------------------------------------
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        for i in range(len(texts)):
            texts[i].page_content = "passage: " + texts[i].page_content

    # Initialize the vectordb
    if ss.db is None:
        ss = initialize_db(ss)

    # Embed into the db
    # ss.db = Chroma.from_documents(texts, ss.embeddings)
    ss.db.add_documents(documents=texts, embedding=ss.embeddings)

    # --------------------------------------
    # QA chain setup
    # --------------------------------------
    if ss.qa_chain is None:

        # QA memory
        human_prefix = "Human: "
        ai_prefix = "AI: "
        qa_template = qa_template_std

        if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            # Rinna-specific settings: newline token and memory prefixes (see the official model page)
            qa_template = qa_template.replace("\n", "<NL>")
            human_prefix = "ユーザー: "
            ai_prefix = "システム: "

        elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"):
            qa_template = qa_template_llama2

        qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template)

        if ss.memory is None:
            conversation_summary_prompt = PromptTemplate(input_variables=['summary', 'new_lines'], template=conversation_summary_template)
            ss.memory = ConversationSummaryBufferMemory(
                llm = ss.llm,
                memory_key = "chat_history",
                input_key = "query",
                output_key = "output_text",
                return_messages = True,
                human_prefix = human_prefix,
                ai_prefix = ai_prefix,
                max_token_limit = 512,
                prompt = conversation_summary_prompt,
            )

        ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)

    return ss

def embed_ref(ss: SessionState, urls: str, fileobj: list, header_lim: int, footer_lim: int) -> (SessionState, str):

    url_flag = "-"
    pdf_flag = "-"

    # --------------------------------------
    # Load URLs and register them in the vectordb
    # --------------------------------------

    # Preprocess the URL list (split into a list, drop duplicates and non-URLs)
    urls = list({url for url in urls.split("\n") if url and "://" in url})

    if urls:
        # Drop URLs already registered (ss.embedded_urls), then record the new ones
        urls = [url for url in urls if url not in ss.embedded_urls]
        ss.embedded_urls.extend(urls)

        # Load the web pages
        loader = SeleniumURLLoader(urls=urls)
        ref_documents = loader.load()

        # Run the embedding process
        ss = embedding_process(ss, ref_documents)

        url_flag = "✅ 登録済"

    # --------------------------------------
    # Strip PDF headers and footers, then register in the vectordb
    # https://pypdf.readthedocs.io/en/stable/user/extract-text.html
    # --------------------------------------

    if fileobj is None:
        pass

    else:
        # Collect the file names
        pdf_paths = []
        for path in fileobj:
            pdf_paths.append(path.name)

        # Initialize the document list
        ref_documents = []

        # Read each PDF file
        for pdf_path in pdf_paths:
            pdf = PdfReader(pdf_path)
            body = []

            def visitor_body(text, cm, tm, font_dict, font_size):
                y = tm[5]
                if y > footer_lim and y < header_lim:  # keep text whose y coordinate lies between footer and header
                    parts.append(text)

            for page in pdf.pages:
                parts = []
                page.extract_text(visitor_text=visitor_body)
                body.append("".join(parts))

            body = "\n".join(body)

            # Keep only the file name from the path
            filename = os.path.basename(pdf_path)
            # Convert the extracted text into a LangChain Document
            ref_documents.append(Document(page_content=body, metadata={"source": filename}))

        # Run the embedding process
        ss = embedding_process(ss, ref_documents)

        pdf_flag = "✅ 登録済"

    langchain.debug = True

    status_message = "URL: " + url_flag + " / PDF: " + pdf_flag
    return ss, status_message
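
# Worked example for the header/footer limits (pypdf reports the y coordinate
# in points, origin at the bottom-left; 1 cm ≈ 28.35 pt): to drop roughly 2 cm
# at the top and 1.5 cm at the bottom of a 792 pt page, pass
# header_lim = 792 - 2 * 28.35 ≈ 735 and footer_lim = 1.5 * 28.35 ≈ 43.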

def clear_db(ss: SessionState) -> (SessionState, str):
    try:
        ss.db.delete_collection()
        ss.db = None  # drop the stale handle so a fresh collection is created next time
        status_message = "✅ 参照データを削除しました。"

    # fix: ss.db is None when nothing is registered, which raises AttributeError
    # (the original caught NameError, which never occurs here)
    except AttributeError:
        status_message = "❌ 参照データが登録されていません。"

    return ss, status_message

# ----------------------------------------------------------------------------
# query input ▶ [def user] ▶ [ def bot ] ▶ [def show_response] ▶ chat screen
#                   ⬇             ⬇                 ⬆
#              chat screen   [qa_predict / chat_predict]
# ----------------------------------------------------------------------------

def user(ss: SessionState, query) -> (SessionState, list):
    # Once the conversation history exceeds a fixed length, drop the oldest entry
    if len(ss.dialogue) > 10:
        ss.dialogue.pop(0)

    ss.dialogue = ss.dialogue + [(query, None)]  # conversation history (None is the empty slot for the bot's reply)
    chat_history = ss.dialogue

    # The chat screen displays chat_history
    return ss, chat_history

def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
    if qa_flag is True:
        ss = qa_predict(ss, query)  # generate an answer with the LLM over the references

    else:
        ss = conversation_prep(ss)
        ss = chat_predict(ss, query)

    return ss, ""  # ss and the query box (cleared)

def chat_predict(ss: SessionState, query) -> SessionState:
    response = ss.conversation_chain.predict(query=query)  # matches the chain's input_key="query"
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)
    return ss

def qa_predict(ss: SessionState, query) -> SessionState:

    # Rinna-specific setting (fix the newline token in the query)
    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        query = query.strip().replace("\n", "<NL>")
    else:
        query = query.strip()

    # Query prefix for multilingual-e5
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        db_query_str = "query: " + query
    else:
        db_query_str = query

    # Retrieve related documents and their sources from the DB
    # (fix: build the source list first so the emptiness check below can actually fire)
    docs = ss.db.similarity_search(db_query_str, k=2)
    source_names = list(set(doc.metadata['source'] for doc in docs if 'source' in doc.metadata))
    sources = ("\n\n[Sources]\n" + '\n - '.join(source_names)) if source_names else ""

    # Rinna-specific setting (fix the newline token in the retrieved documents)
    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        for i in range(len(docs)):
            docs[i].page_content = docs[i].page_content.strip().replace("\n", "<NL>")

    # Generate an answer (up to three attempts)
    for _ in range(3):
        result = ss.qa_chain({"input_documents": docs, "query": query})
        result["output_text"] = result["output_text"].replace("<NL>", "\n").strip("...").strip("回答:").strip()

        # If result["output_text"] is not empty, update the memory and return
        if result["output_text"] != "":
            response = result["output_text"] + sources
            # fix: the chain saved both the user and AI message; drop both before re-adding the cleaned pair
            ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]
            ss.memory.chat_memory.add_user_message(query.replace("<NL>", "\n"))
            ss.memory.chat_memory.add_ai_message(response)
            ss.dialogue[-1] = (ss.dialogue[-1][0], response)
            return ss
        else:
            # If empty, drop the latest turn from memory and retry
            ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]

    # Still empty after three attempts
    response = "3回試行しましたが、情報を生成できませんでした。"
    if sources != "":
        response += "参考文献の抽出には成功していますので、言語モデルを変えてお試しください。"

    # Add the user and AI messages
    ss.memory.chat_memory.add_user_message(query.replace("<NL>", "\n"))
    ss.memory.chat_memory.add_ai_message(response)
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)  # conversation history
    return ss

# Stream the answer to the chat screen one character at a time
def show_response(ss: SessionState) -> str:
    # chat_history = ss.load_chat_history()  # fetch the chat history from memory as a list
    # response = chat_history[-1][1]         # take the bot's answer [1] from the latest turn [-1]
    # chat_history[-1][1] = ""               # blank the answer so it can be re-typed gradually

    chat_history = [list(item) for item in ss.dialogue]  # convert tuples to lists to get a mutable history
    response = chat_history[-1][1]  # take the bot's answer [1] from the latest turn [-1]
    chat_history[-1][1] = ""        # blank the answer so it can be re-typed gradually

    for character in response:
        chat_history[-1][1] += character
        time.sleep(0.05)
        yield chat_history

with gr.Blocks() as demo:

    # Instantiate per-user session state (reset on page reload)
    ss = gr.State(SessionState())

    # --------------------------------------
    # Functions to set / clear the API key
    # --------------------------------------
    def openai_api_setfn(openai_api_key) -> str:
        if not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50:
            os.environ["OPENAI_API_KEY"] = ""
            status_message = "❌ 有効なAPIキーを入力してください"
            return status_message
        else:
            os.environ["OPENAI_API_KEY"] = openai_api_key
            status_message = "✅ APIキーを設定しました"
            return status_message

    def openai_api_clsfn(openai_api_key) -> (str, str):
        os.environ["OPENAI_API_KEY"] = ""
        status_message = "✅ APIキーの削除が完了しました"
        return status_message, ""

    # --------------------------------------
    # "Continue the answer" button
    # --------------------------------------
    def continue_pred():
        query = "回答を続けてください"
        return query

    with gr.Tabs():
        # --------------------------------------
        # Setting Tab
        # --------------------------------------
        with gr.TabItem("1. LLM設定"):
            with gr.Row():
                model_id = gr.Dropdown(
                    choices=[
                        'elyza/ELYZA-japanese-Llama-2-7b-fast-instruct',
                        'rinna/bilingual-gpt-neox-4b-instruction-sft',
                        'gpt-3.5-turbo',
                    ],
                    value="elyza/ELYZA-japanese-Llama-2-7b-fast-instruct",
                    label='LLM model',
                    interactive=True,
                )
            with gr.Row():
                embedding_id = gr.Dropdown(
                    choices=[
                        'intfloat/multilingual-e5-large',
                        'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
                        'oshizo/sbert-jsnli-luke-japanese-base-lite',
                        'text-embedding-ada-002',
                        "None"
                    ],
                    value="sonoisa/sentence-bert-base-ja-mean-tokens-v2",
                    label='Embedding model',
                    interactive=True,
                )
            with gr.Row():
                with gr.Column(scale=19):
                    openai_api_key = gr.Textbox(label="OpenAI API Key (Optional)", interactive=True, type="password", value="", placeholder="Your OpenAI API Key for OpenAI models.", max_lines=1)
                with gr.Column(scale=1):
                    openai_api_set = gr.Button(value="Set API KEY", size="sm")
                    openai_api_cls = gr.Button(value="Delete API KEY", size="sm")

            # Advanced settings (collapsible)
            with gr.Accordion(label="Advanced Setting", open=False):
                with gr.Row():
                    with gr.Column():
                        load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
                        verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=False)
                    with gr.Column():
                        temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
                    with gr.Column():
                        min_length = gr.Slider(label="min_length (HF)", minimum=1, maximum=100, step=1, value=10, interactive=True)
                    with gr.Column():
                        max_new_tokens = gr.Slider(label="max_tokens(OpenAI), max_new_tokens(HF)", minimum=1, maximum=1024, step=1, value=256, interactive=True)
                    with gr.Column():
                        top_k = gr.Slider(label='top_k (HF)', minimum=1, maximum=100, step=1, value=40, interactive=True)
                    with gr.Column():
                        top_p = gr.Slider(label='top_p (HF)', minimum=0.01, maximum=0.99, step=0.01, value=0.92, interactive=True)
                    with gr.Column():
                        repetition_penalty = gr.Slider(label='repetition_penalty (HF)', minimum=0.5, maximum=2, step=0.1, value=1.2, interactive=True)
                    with gr.Column():
                        num_return_sequences = gr.Slider(label='num_return_sequences (HF)', minimum=1, maximum=20, step=1, value=3, interactive=True)

            with gr.Row():
                with gr.Column(scale=2):
                    config_btn = gr.Button(value="Configure")
                with gr.Column(scale=13):
                    status_cfg = gr.Textbox(show_label=False, interactive=False, value="モデルを設定してください", container=False, max_lines=1)

            # Wire up the buttons and other events
            openai_api_set.click(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            openai_api_cls.click(openai_api_clsfn, inputs=[openai_api_key], outputs=[status_cfg, openai_api_key], show_progress="full")
            openai_api_key.submit(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            config_btn.click(
                fn = load_models,
                inputs = [ss, model_id, embedding_id, openai_api_key, load_in_8bit, verbose, temperature,
                          min_length, max_new_tokens, top_k, top_p, repetition_penalty, num_return_sequences],
                outputs = [ss, status_cfg],
                queue = True,
                show_progress = "full"
            )

        # --------------------------------------
        # Reference Tab
        # --------------------------------------
        with gr.TabItem("2. References"):
            urls = gr.TextArea(
                max_lines = 60,
                show_label = False,
                info = "List any reference URLs for Q&A retrieval.",
                placeholder = "https://blog.kikagaku.co.jp/deep-learning-transformer\nhttps://note.com/elyza/n/na405acaca130",
                interactive = True,
            )

            with gr.Row():
                pdf_paths = gr.File(label="PDFs", height=150, min_width=60, scale=7, file_types=[".pdf"], file_count="multiple", interactive=True)
                header_lim = gr.Number(label="Header (pt)", step=1, value=792, precision=0, min_width=70, scale=1, interactive=True)
                footer_lim = gr.Number(label="Footer (pt)", step=1, value=0, precision=0, min_width=70, scale=1, interactive=True)
                pdf_ref = gr.Textbox(show_label=False, value="A4 Size:\n(下)0-792pt(上)\n *28.35pt/cm", container=False, scale=1, interactive=False)

            with gr.Row():
                ref_set_btn = gr.Button(value="コンテンツ登録", scale=1)
                ref_clear_btn = gr.Button(value="登録データ削除", scale=1)
                status_ref = gr.Textbox(show_label=False, interactive=False, value="参照データ未登録", container=False, max_lines=1, scale=18)

            ref_set_btn.click(fn=embed_ref, inputs=[ss, urls, pdf_paths, header_lim, footer_lim], outputs=[ss, status_ref], queue=True, show_progress="full")
            ref_clear_btn.click(fn=clear_db, inputs=[ss], outputs=[ss, status_ref], show_progress="full")

        # --------------------------------------
        # Chatbot Tab
        # --------------------------------------
        with gr.TabItem("3. Q&A Chat"):
            chat_history = gr.Chatbot([], elem_id="chatbot").style(height=600, color_map=('green', 'gray'))
            with gr.Row():
                with gr.Column(scale=95):
                    query = gr.Textbox(
                        show_label=False,
                        placeholder="Send a message with [Shift]+[Enter] key.",
                        lines=4,
                        container=False,
                        autofocus=True,
                        interactive=True,
                    )
                with gr.Column(scale=5):
                    qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=True)
                    query_send_btn = gr.Button(value="▶")

            # gr.Examples(["機械学習について説明してください"], inputs=[query])
            query.submit(user, [ss, query], [ss, chat_history]).then(bot, [ss, query, qa_flag], [ss, query]).then(show_response, [ss], [chat_history])
            query_send_btn.click(user, [ss, query], [ss, chat_history]).then(bot, [ss, query, qa_flag], [ss, query]).then(show_response, [ss], [chat_history])

if __name__ == "__main__":
    demo.queue(concurrency_count=5)
    demo.launch(debug=True, inbrowser=True)
requirements.txt
ADDED
@@ -0,0 +1,18 @@
accelerate
bitsandbytes
transformers
sentence_transformers
sentencepiece
accelerate
bitsandbytes
langchain
xformers
chromadb
gradio
openai
tiktoken
fugashi
ipadic
unstructured
selenium
pypdf