zym0216 commited on
Commit
4e088e2
·
verified ·
1 Parent(s): 594155c

Update chat.py

Browse files
Files changed (1) hide show
  1. chat.py +28 -53
chat.py CHANGED
@@ -1,63 +1,40 @@
1
  import os
2
  import openai
3
- import torch
4
-
5
- import llama_index
6
- from llama_index import ServiceContext, set_global_service_context
7
- from llama_index.prompts import PromptTemplate
8
- from llama_index.llms import HuggingFaceLLM
9
- from llama_index import VectorStoreIndex, SimpleDirectoryReader
 
 
10
 
11
 
 
 
 
 
 
12
 
13
- def messages_to_prompt(messages):
14
- prompt = ""
15
- for message in messages:
16
- if message.role == 'system':
17
- prompt += f"<|system|>\n{message.content}</s>\n"
18
- elif message.role == 'user':
19
- prompt += f"<|user|>\n{message.content}</s>\n"
20
- elif message.role == 'assistant':
21
- prompt += f"<|assistant|>\n{message.content}</s>\n"
22
 
23
- # ensure we start with a system prompt, insert blank if needed
24
- if not prompt.startswith("<|system|>\n"):
25
- prompt = "<|system|>\n</s>\n" + prompt
26
 
27
- # add final assistant prompt
28
- prompt = prompt + "<|assistant|>\n"
29
 
30
- return prompt
31
 
 
 
 
 
 
32
 
33
- class RetrievalChatbot():
34
- def __init__(self, api_key, api_base, pdf_dir, model_name):
35
- openai.api_key = api_key
36
- openai.api_base = api_base
37
- self.model_name = model_name
38
- documents = SimpleDirectoryReader(input_dir="papers_all").load_data()
39
- print("find doc")
40
- llm_zephyr = HuggingFaceLLM(
41
- model_name="HuggingFaceH4/zephyr-7b-beta",
42
- tokenizer_name="HuggingFaceH4/zephyr-7b-beta",
43
- query_wrapper_prompt=PromptTemplate("<|system|>\n</s>\n<|user|>\n{query_str}</s>\n<|assistant|>\n"),
44
- context_window=2048,
45
- max_new_tokens=128,
46
- messages_to_prompt=messages_to_prompt,
47
- device_map="auto",
48
  )
49
-
50
- print("loaded llm")
51
- service_context = ServiceContext.from_defaults(llm=llm_zephyr, chunk_size=512)
52
- # set_global_service_context(service_context)
53
- print("loaded doc")
54
- index = VectorStoreIndex.from_documents(documents, service_context=service_context)
55
- # index.storage_context.persist(persist_dir="index")
56
- print("save index")
57
- qa = index.as_query_engine(streaming=True)
58
- self.qa=qa
59
-
60
-
61
 
62
  self.prompt = (
63
  "Please answer the following question using information with the assistance of the given context.\n",
@@ -98,10 +75,8 @@ class RetrievalChatbot():
98
 
99
  for rephrase in subquestion_list:
100
  query = "".join(self.prompt).format(message=rephrase)
101
- print("query:",query)
102
- response = self.qa.query(query)
103
- print("respponse:",response)
104
- responses_list.append(response)
105
 
106
  summarize_prompt = "".join(self.summarize_prompt).format(question=message, answers=responses_list)
107
  summarized_answer = self.get_openai_response(summarize_prompt, self.model_name)
 
1
  import os
2
  import openai
3
+ import langchain
4
+ from langchain.document_loaders import PyMuPDFLoader, DirectoryLoader
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.embeddings import OpenAIEmbeddings
7
+ from langchain.vectorstores import Chroma
8
+ from langchain.chat_models import ChatOpenAI
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.globals import set_verbose
11
+ from langchain.llms import VLLM
12
 
13
 
14
+ class RetrievalChatbot():
15
+ def __init__(self, api_key, api_base, model_name):
16
+ openai.api_key = api_key
17
+ #openai.api_base = api_base
18
+ self.model_name = model_name
19
 
20
+ set_verbose(True)
 
 
 
 
 
 
 
 
21
 
 
 
 
22
 
 
 
23
 
24
+ embeddings = OpenAIEmbeddings(max_retries=100)
25
 
26
+ if os.path.exists("persist"):
27
+ vectordb = Chroma(persist_directory="persist", embedding_function=embeddings)
28
+ print("loaded existing database")
29
+ else:
30
+ os.mkdir("persist")
31
 
32
+ retriever = vectordb.as_retriever(
33
+ search_type="mmr",
34
+ search_kwargs={"k": 10, "fetch_k": 50}
 
 
 
 
 
 
 
 
 
 
 
 
35
  )
36
+ llm = ChatOpenAI(model_name="gpt-4-1106-preview")
37
+ self.qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
 
 
 
 
 
 
 
 
 
 
38
 
39
  self.prompt = (
40
  "Please answer the following question using information with the assistance of the given context.\n",
 
75
 
76
  for rephrase in subquestion_list:
77
  query = "".join(self.prompt).format(message=rephrase)
78
+ retrieval_response = self.qa(query)["result"]
79
+ responses_list.append(retrieval_response)
 
 
80
 
81
  summarize_prompt = "".join(self.summarize_prompt).format(question=message, answers=responses_list)
82
  summarized_answer = self.get_openai_response(summarize_prompt, self.model_name)