Adam Fallon committed on
Commit
6f732f7
·
1 Parent(s): 0e8964f

simplify and add description

Browse files
Files changed (1) hide show
  1. app.py +71 -16
app.py CHANGED
@@ -8,13 +8,14 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
  from langchain.llms import HuggingFaceHub
9
  from tl_loaders.TrainlineTrainTimeLoader import TrainlineTrainTimeLoader
10
  from langchain.chains import RetrievalQA
 
11
  import gradio as gr
12
 
13
  from dotenv import load_dotenv
14
 
15
  load_dotenv()
16
  embedding_model = "sentence-transformers/all-mpnet-base-v2"
17
- persist_directory = "docs/chroma/openai"
18
  chunk_size = 1000
19
  chunk_overlap = 0
20
 
@@ -33,7 +34,7 @@ greet = "Ask a question in the like 'How many trains per day from Rome to Madrid
33
  newline = "\n"
34
  llm = None
35
  db = None
36
- force_reindex = True
37
  chat_history = []
38
 
39
  # Few random ones and top results from https://www.thetrainline.com/train-times
@@ -56,14 +57,33 @@ urls = {
56
  }
57
 
58
 
59
- def ask_question(message, history):
 
 
 
 
 
 
 
 
 
 
 
60
  qa = RetrievalQA.from_chain_type(
61
  llm=llm,
62
  chain_type="stuff",
63
  retriever=db.as_retriever(search_kwargs=search_kwargs),
64
  return_source_documents=True,
 
 
 
 
 
 
65
  )
66
 
 
 
67
  # result = qa(f"{newline.join(chat_history)}\n[INST]{message}[/INST]")
68
  result = qa(f"[INST]{message}[/INST]")
69
 
@@ -77,26 +97,61 @@ def ask_question(message, history):
77
  else:
78
  try:
79
  source = result["source_documents"][0].metadata["source"]
80
- return f"{answer}\n[Source]({source})"
 
81
  except:
82
  return f"{answer}"
83
 
84
 
85
  def setup_gradio():
86
- demo = gr.ChatInterface(
87
- fn=ask_question,
88
- examples=[
89
- "Trains per day from London to Edinburgh?",
90
- "When is the last train from Madrid to Barcelona?",
91
- "Train and bus operators from Rome to Madrid?",
92
- "How many changes from Barcelona to Madrid?",
93
- "Price from London to Madrid?",
94
- ],
95
- title="Trainline Q & A 🤖",
96
- description=f"Ask questions about routes. Supported routes: {', '.join(urls.values())}",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  )
 
98
 
99
- demo.launch()
100
 
101
  def load_docs():
102
  loader = TrainlineTrainTimeLoader(list(urls.keys()), urls_to_od_pair=urls)
 
8
  from langchain.llms import HuggingFaceHub
9
  from tl_loaders.TrainlineTrainTimeLoader import TrainlineTrainTimeLoader
10
  from langchain.chains import RetrievalQA
11
+ from langchain.prompts import PromptTemplate
12
  import gradio as gr
13
 
14
  from dotenv import load_dotenv
15
 
16
  load_dotenv()
17
  embedding_model = "sentence-transformers/all-mpnet-base-v2"
18
+ persist_directory = "docs/chroma/"
19
  chunk_size = 1000
20
  chunk_overlap = 0
21
 
 
34
  newline = "\n"
35
  llm = None
36
  db = None
37
+ force_reindex = False
38
  chat_history = []
39
 
40
  # Few random ones and top results from https://www.thetrainline.com/train-times
 
57
  }
58
 
59
 
60
+ def ask_question(message):
61
+ prompt = """
62
+ Use the following pieces of context to answer the question at the end.
63
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
64
+ DO NOT RAMBLE or try to infer information.
65
+ Just give the exact information requested.
66
+ Take a deep breath and work on this problem step-by-step.
67
+
68
+ {context}
69
+
70
+ Question: {question}
71
+ """
72
  qa = RetrievalQA.from_chain_type(
73
  llm=llm,
74
  chain_type="stuff",
75
  retriever=db.as_retriever(search_kwargs=search_kwargs),
76
  return_source_documents=True,
77
+ chain_type_kwargs={
78
+ "prompt": PromptTemplate(
79
+ template=prompt,
80
+ input_variables=["context", "question"],
81
+ ),
82
+ },
83
  )
84
 
85
+ print(qa.combine_documents_chain.llm_chain.prompt.template)
86
+
87
  # result = qa(f"{newline.join(chat_history)}\n[INST]{message}[/INST]")
88
  result = qa(f"[INST]{message}[/INST]")
89
 
 
97
  else:
98
  try:
99
  source = result["source_documents"][0].metadata["source"]
100
+ doc = result["source_documents"][0].page_content
101
+ return f"{answer}\n---\nInformation found in text: {doc}\nSource: {source}"
102
  except:
103
  return f"{answer}"
104
 
105
 
106
  def setup_gradio():
107
+ desc = "Welcome to Trainline Train Time Q&A! A silly little demo of how RAG can ground the answers from an LLM."
108
+ long_desc = """
109
+ Ask a question in the like 'How many trains per day from Rome to Madrid'.
110
+
111
+ This info is scrapped from the table on train time pages. Example page here: https://www.thetrainline.com/train-times/manchester-to-london
112
+
113
+ Not all train time pages have been scrapped so you don't get answers for every route, just the ones in the supported routes below and only answers for the supported questions.
114
+
115
+ Supported questions are;
116
+ - Price from X to Y
117
+ - Last train from X to Y
118
+ - First train from X to Y
119
+ - Frequency of trains from X to Y
120
+ - Price of trains from X to Y
121
+ - Operators of trains and buses from X to Y
122
+ - Distance from X to Y
123
+ - Number of Changes from X to Y
124
+ - Journey time from X to Y
125
+
126
+ Supported routes are;
127
+ - "London to Edinburgh",
128
+ - "Madrid to Barcelona",
129
+ - "Rome to Madrid",
130
+ - "Barcelona to Madrid",
131
+ - "London to Madrid",
132
+ - "London to Manchester",
133
+ - "Leeds to London",
134
+ - "London to Birmingham",
135
+ - "London to Brighton",
136
+ - "Glasgow to Manchester",
137
+ - "Glasgow to Liverpool",
138
+ - "Glasgow to Leeds",
139
+ - "Birmingham to Glasgow",
140
+ - "London to Newcastle",
141
+ - "Seville to Madrid",
142
+ """
143
+
144
+ iface = gr.Interface(
145
+ ask_question,
146
+ inputs="text",
147
+ outputs="text",
148
+ allow_screenshot=False,
149
+ allow_flagging=False,
150
+ description=desc,
151
+ article=long_desc
152
  )
153
+ iface.launch()
154
 
 
155
 
156
  def load_docs():
157
  loader = TrainlineTrainTimeLoader(list(urls.keys()), urls_to_od_pair=urls)