Spaces:
Runtime error
Runtime error
Adam Fallon commited on
Commit ·
6f732f7
1
Parent(s): 0e8964f
simplify and add description
Browse files
app.py
CHANGED
|
@@ -8,13 +8,14 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
| 8 |
from langchain.llms import HuggingFaceHub
|
| 9 |
from tl_loaders.TrainlineTrainTimeLoader import TrainlineTrainTimeLoader
|
| 10 |
from langchain.chains import RetrievalQA
|
|
|
|
| 11 |
import gradio as gr
|
| 12 |
|
| 13 |
from dotenv import load_dotenv
|
| 14 |
|
| 15 |
load_dotenv()
|
| 16 |
embedding_model = "sentence-transformers/all-mpnet-base-v2"
|
| 17 |
-
persist_directory = "docs/chroma/
|
| 18 |
chunk_size = 1000
|
| 19 |
chunk_overlap = 0
|
| 20 |
|
|
@@ -33,7 +34,7 @@ greet = "Ask a question in the like 'How many trains per day from Rome to Madrid
|
|
| 33 |
newline = "\n"
|
| 34 |
llm = None
|
| 35 |
db = None
|
| 36 |
-
force_reindex =
|
| 37 |
chat_history = []
|
| 38 |
|
| 39 |
# Few random ones and top results from https://www.thetrainline.com/train-times
|
|
@@ -56,14 +57,33 @@ urls = {
|
|
| 56 |
}
|
| 57 |
|
| 58 |
|
| 59 |
-
def ask_question(message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
qa = RetrievalQA.from_chain_type(
|
| 61 |
llm=llm,
|
| 62 |
chain_type="stuff",
|
| 63 |
retriever=db.as_retriever(search_kwargs=search_kwargs),
|
| 64 |
return_source_documents=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
)
|
| 66 |
|
|
|
|
|
|
|
| 67 |
# result = qa(f"{newline.join(chat_history)}\n[INST]{message}[/INST]")
|
| 68 |
result = qa(f"[INST]{message}[/INST]")
|
| 69 |
|
|
@@ -77,26 +97,61 @@ def ask_question(message, history):
|
|
| 77 |
else:
|
| 78 |
try:
|
| 79 |
source = result["source_documents"][0].metadata["source"]
|
| 80 |
-
|
|
|
|
| 81 |
except:
|
| 82 |
return f"{answer}"
|
| 83 |
|
| 84 |
|
| 85 |
def setup_gradio():
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
)
|
|
|
|
| 98 |
|
| 99 |
-
demo.launch()
|
| 100 |
|
| 101 |
def load_docs():
|
| 102 |
loader = TrainlineTrainTimeLoader(list(urls.keys()), urls_to_od_pair=urls)
|
|
|
|
| 8 |
from langchain.llms import HuggingFaceHub
|
| 9 |
from tl_loaders.TrainlineTrainTimeLoader import TrainlineTrainTimeLoader
|
| 10 |
from langchain.chains import RetrievalQA
|
| 11 |
+
from langchain.prompts import PromptTemplate
|
| 12 |
import gradio as gr
|
| 13 |
|
| 14 |
from dotenv import load_dotenv
|
| 15 |
|
| 16 |
load_dotenv()
|
| 17 |
embedding_model = "sentence-transformers/all-mpnet-base-v2"
|
| 18 |
+
persist_directory = "docs/chroma/"
|
| 19 |
chunk_size = 1000
|
| 20 |
chunk_overlap = 0
|
| 21 |
|
|
|
|
| 34 |
newline = "\n"
|
| 35 |
llm = None
|
| 36 |
db = None
|
| 37 |
+
force_reindex = False
|
| 38 |
chat_history = []
|
| 39 |
|
| 40 |
# Few random ones and top results from https://www.thetrainline.com/train-times
|
|
|
|
| 57 |
}
|
| 58 |
|
| 59 |
|
| 60 |
+
def ask_question(message):
|
| 61 |
+
prompt = """
|
| 62 |
+
Use the following pieces of context to answer the question at the end.
|
| 63 |
+
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
| 64 |
+
DO NOT RAMBLE or try to infer information.
|
| 65 |
+
Just give the exact information requested.
|
| 66 |
+
Take a deep breath and work on this problem step-by-step.
|
| 67 |
+
|
| 68 |
+
{context}
|
| 69 |
+
|
| 70 |
+
Question: {question}
|
| 71 |
+
"""
|
| 72 |
qa = RetrievalQA.from_chain_type(
|
| 73 |
llm=llm,
|
| 74 |
chain_type="stuff",
|
| 75 |
retriever=db.as_retriever(search_kwargs=search_kwargs),
|
| 76 |
return_source_documents=True,
|
| 77 |
+
chain_type_kwargs={
|
| 78 |
+
"prompt": PromptTemplate(
|
| 79 |
+
template=prompt,
|
| 80 |
+
input_variables=["context", "question"],
|
| 81 |
+
),
|
| 82 |
+
},
|
| 83 |
)
|
| 84 |
|
| 85 |
+
print(qa.combine_documents_chain.llm_chain.prompt.template)
|
| 86 |
+
|
| 87 |
# result = qa(f"{newline.join(chat_history)}\n[INST]{message}[/INST]")
|
| 88 |
result = qa(f"[INST]{message}[/INST]")
|
| 89 |
|
|
|
|
| 97 |
else:
|
| 98 |
try:
|
| 99 |
source = result["source_documents"][0].metadata["source"]
|
| 100 |
+
doc = result["source_documents"][0].page_content
|
| 101 |
+
return f"{answer}\n---\nInformation found in text: {doc}\nSource: {source}"
|
| 102 |
except:
|
| 103 |
return f"{answer}"
|
| 104 |
|
| 105 |
|
| 106 |
def setup_gradio():
|
| 107 |
+
desc = "Welcome to Trainline Train Time Q&A! A silly little demo of how RAG can ground the answers from an LLM."
|
| 108 |
+
long_desc = """
|
| 109 |
+
Ask a question in the like 'How many trains per day from Rome to Madrid'.
|
| 110 |
+
|
| 111 |
+
This info is scrapped from the table on train time pages. Example page here: https://www.thetrainline.com/train-times/manchester-to-london
|
| 112 |
+
|
| 113 |
+
Not all train time pages have been scrapped so you don't get answers for every route, just the ones in the supported routes below and only answers for the supported questions.
|
| 114 |
+
|
| 115 |
+
Supported questions are;
|
| 116 |
+
- Price from X to Y
|
| 117 |
+
- Last train from X to Y
|
| 118 |
+
- First train from X to Y
|
| 119 |
+
- Frequency of trains from X to Y
|
| 120 |
+
- Price of trains from X to Y
|
| 121 |
+
- Operators of trains and buses from X to Y
|
| 122 |
+
- Distance from X to Y
|
| 123 |
+
- Number of Changes from X to Y
|
| 124 |
+
- Journey time from X to Y
|
| 125 |
+
|
| 126 |
+
Supported routes are;
|
| 127 |
+
- "London to Edinburgh",
|
| 128 |
+
- "Madrid to Barcelona",
|
| 129 |
+
- "Rome to Madrid",
|
| 130 |
+
- "Barcelona to Madrid",
|
| 131 |
+
- "London to Madrid",
|
| 132 |
+
- "London to Manchester",
|
| 133 |
+
- "Leeds to London",
|
| 134 |
+
- "London to Birmingham",
|
| 135 |
+
- "London to Brighton",
|
| 136 |
+
- "Glasgow to Manchester",
|
| 137 |
+
- "Glasgow to Liverpool",
|
| 138 |
+
- "Glasgow to Leeds",
|
| 139 |
+
- "Birmingham to Glasgow",
|
| 140 |
+
- "London to Newcastle",
|
| 141 |
+
- "Seville to Madrid",
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
iface = gr.Interface(
|
| 145 |
+
ask_question,
|
| 146 |
+
inputs="text",
|
| 147 |
+
outputs="text",
|
| 148 |
+
allow_screenshot=False,
|
| 149 |
+
allow_flagging=False,
|
| 150 |
+
description=desc,
|
| 151 |
+
article=long_desc
|
| 152 |
)
|
| 153 |
+
iface.launch()
|
| 154 |
|
|
|
|
| 155 |
|
| 156 |
def load_docs():
|
| 157 |
loader = TrainlineTrainTimeLoader(list(urls.keys()), urls_to_od_pair=urls)
|