qdi0 committed on
Commit
d5275a7
·
1 Parent(s): b1321ac

init commit

Browse files
Files changed (3) hide show
  1. .gitignore +4 -0
  2. app.py +83 -0
  3. requirements.txt +108 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .venv
2
+ .env
3
+ .chroma
4
+ .DS_Store
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.vectorstores import Chroma
2
+ from langchain.embeddings import OpenAIEmbeddings
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain.llms import OpenAI
5
+ from langchain.chains import VectorDBQA, RetrievalQA
6
+ from langchain.chat_models import ChatOpenAI
7
+ from langchain.document_loaders import TextLoader, PyPDFLoader
8
+ from langchain import PromptTemplate
9
+ from PyPDF2 import PdfFileMerger
10
+ import gradio as gr
11
+ from dotenv import load_dotenv
12
+ import openai
13
+ import glob
14
+ import os
15
+
16
# ---------------------------------------------------------------------------
# One-time startup: load configuration, build the merged PDF corpus, index it
# into Chroma, and wire up the RetrievalQA chain used by the Gradio callbacks.
# ---------------------------------------------------------------------------
load_dotenv()
# Fail fast when the key is absent.  The original line
# (`os.environ["OPENAI_API_KEY"] = os.environ['OPENAI_API_KEY']`) was a no-op
# self-assignment whose only effect was raising a bare KeyError on a missing
# key; make the intent explicit while raising the same exception type.
if "OPENAI_API_KEY" not in os.environ:
    raise KeyError("OPENAI_API_KEY")

# Merge every source PDF into a single file once; later runs reuse it.
merge_file = 'src/retrieval_qa/pdf/merge.pdf'
if not os.path.isfile(merge_file):
    pdf_file_merger = PdfFileMerger()
    try:
        for file_name in glob.glob('src/retrieval_qa/pdf/*.pdf'):
            pdf_file_merger.append(file_name)
        pdf_file_merger.write(merge_file)
    finally:
        # Release the open PDF file handles even if append/write fails.
        pdf_file_merger.close()

# Load the merged PDF and split it into ~1000-character chunks (no overlap)
# for embedding.
loader = PyPDFLoader(merge_file)
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
embeddings = OpenAIEmbeddings()
vectordb = Chroma.from_documents(texts, embeddings)

# "stuff" chain: all retrieved chunks are packed into a single LLM call.
qa = RetrievalQA.from_chain_type(llm=ChatOpenAI(
    model_name="gpt-3.5-turbo"), chain_type="stuff", retriever=vectordb.as_retriever())

# Prompt definition (Japanese user-facing text — must stay byte-identical).
template = """
あなたは再生医療・美容医学について学習したAIアシスタントです。下記の質問に具体的で医学的な回答をしてください。
質問:{question}
回答:
"""

prompt = PromptTemplate(
    input_variables=["question"],
    template=template,
)
50
+
51
+
52
def add_text(history, text):
    """Append the user's message as a new (text, None) turn and clear the box.

    history -- gradio Chatbot state: list of (user_text, bot_text) pairs.
    text    -- the text just submitted by the user.
    Returns the extended history and "" (which resets the Textbox component).
    """
    updated = history + [(text, None)]
    return updated, ""
55
+
56
+
57
def bot(history):
    """Answer the newest user message in *history* via the RetrievalQA chain.

    history -- gradio Chatbot state: a list of [user_text, bot_text] pairs
               whose last entry has bot_text still unset.
    Returns the same history with the last entry's answer filled in.
    """
    question = history[-1][0]
    # Wrap the raw question in the medical-assistant prompt template before
    # handing it to the chain.
    formatted_query = prompt.format(question=question)
    answer = qa.run(formatted_query)
    # NOTE(review): the original also fetched the top source document through
    # the private `qa._get_docs(...)` API to build a citation string, but the
    # only line using that string was commented out — so the second
    # vector-store retrieval was pure dead cost and has been removed.  To
    # surface sources, rebuild `qa` with `return_source_documents=True` and
    # read them from the chain's result dict instead of a private method.
    history[-1][1] = answer
    return history
67
+
68
+
69
# --- Gradio UI wiring --------------------------------------------------------
with gr.Blocks() as demo:
    # Chat transcript pane.
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400)

    with gr.Row():
        with gr.Column(scale=0.6):
            # Question input; pressing Enter submits.
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press enter",
            ).style(container=False)

    # On submit: add_text records the question and clears the box, then bot
    # fills in the answer for that latest turn.
    txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(bot, chatbot, chatbot)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==4.2.2
5
+ anyio==3.6.2
6
+ async-timeout==4.0.2
7
+ attrs==22.2.0
8
+ backoff==2.2.1
9
+ cachetools==5.3.0
10
+ certifi==2022.12.7
11
+ charset-normalizer==3.1.0
12
+ chromadb==0.3.21
13
+ click==8.1.3
14
+ clickhouse-connect==0.5.20
15
+ contourpy==1.0.7
16
+ cycler==0.11.0
17
+ dataclasses-json==0.5.7
18
+ duckdb==0.7.1
19
+ entrypoints==0.4
20
+ fastapi==0.95.1
21
+ ffmpy==0.3.0
22
+ filelock==3.11.0
23
+ fonttools==4.39.3
24
+ frozenlist==1.3.3
25
+ fsspec==2023.4.0
26
+ gptcache==0.1.10
27
+ gradio==3.26.0
28
+ gradio_client==0.1.2
29
+ h11==0.14.0
30
+ hnswlib==0.7.0
31
+ httpcore==0.17.0
32
+ httptools==0.5.0
33
+ httpx==0.24.0
34
+ huggingface-hub==0.13.4
35
+ idna==3.4
36
+ importlib-resources==5.12.0
37
+ Jinja2==3.1.2
38
+ joblib==1.2.0
39
+ jsonschema==4.17.3
40
+ kiwisolver==1.4.4
41
+ langchain==0.0.139
42
+ linkify-it-py==2.0.0
43
+ lz4==4.3.2
44
+ markdown-it-py==2.2.0
45
+ MarkupSafe==2.1.2
46
+ marshmallow==3.19.0
47
+ marshmallow-enum==1.5.1
48
+ matplotlib==3.7.1
49
+ mdit-py-plugins==0.3.3
50
+ mdurl==0.1.2
51
+ monotonic==1.6
52
+ mpmath==1.3.0
53
+ multidict==6.0.4
54
+ mypy-extensions==1.0.0
55
+ networkx==3.1
56
+ nltk==3.8.1
57
+ numpy==1.24.2
58
+ openai==0.27.4
59
+ openapi-schema-pydantic==1.2.4
60
+ orjson==3.8.10
61
+ packaging==23.1
62
+ pandas==2.0.0
63
+ Pillow==9.5.0
64
+ posthog==2.5.0
65
+ pydantic==1.10.7
66
+ pydub==0.25.1
67
+ pyparsing==3.0.9
68
+ pypdf==3.7.1
69
+ PyPDF2==2.0.0
70
+ pyrsistent==0.19.3
71
+ python-dateutil==2.8.2
72
+ python-dotenv==1.0.0
73
+ python-multipart==0.0.6
74
+ pytz==2023.3
75
+ PyYAML==6.0
76
+ regex==2023.3.23
77
+ requests==2.28.2
78
+ scikit-learn==1.2.2
79
+ scipy==1.10.1
80
+ semantic-version==2.10.0
81
+ sentence-transformers==2.2.2
82
+ sentencepiece==0.1.98
83
+ six==1.16.0
84
+ sniffio==1.3.0
85
+ SQLAlchemy==1.4.47
86
+ starlette==0.26.1
87
+ sympy==1.11.1
88
+ tenacity==8.2.2
89
+ threadpoolctl==3.1.0
90
+ tiktoken==0.3.3
91
+ tokenizers==0.13.3
92
+ toolz==0.12.0
93
+ torch==2.0.0
94
+ torchvision==0.15.1
95
+ tqdm==4.65.0
96
+ transformers==4.28.0
97
+ typing-inspect==0.8.0
98
+ typing_extensions==4.5.0
99
+ tzdata==2023.3
100
+ uc-micro-py==1.0.1
101
+ urllib3==1.26.15
102
+ uvicorn==0.21.1
103
+ uvloop==0.17.0
104
+ watchfiles==0.19.0
105
+ websockets==11.0.1
106
+ yarl==1.8.2
107
+ zipp==3.15.0
108
+ zstandard==0.20.0