Can Günen commited on
Commit
737f413
·
1 Parent(s): 22f5a4b

divided py files and added more error catching

Browse files
Files changed (2) hide show
  1. app.py +15 -75
  2. document_chatbot.py +75 -0
app.py CHANGED
@@ -2,68 +2,7 @@ import os
2
  import random
3
  import time
4
  import gradio as gr
5
- import subprocess
6
- import requests
7
- from langchain.chains.question_answering import load_qa_chain
8
- from langchain.text_splitter import CharacterTextSplitter
9
- from langchain.embeddings import HuggingFaceEmbeddings
10
- from langchain.docstore.document import Document
11
- from langchain.document_loaders import TextLoader
12
- from langchain.vectorstores import FAISS
13
- from langchain import HuggingFaceHub
14
-
15
-
16
- class DocumentChatbot:
17
- def __init__(self):
18
- self.llm = None
19
- self.chain = None
20
- self.embeddings = None
21
- self.metadata = {"source": "internet"}
22
- self.init_mes = ["According to the document, ", "Based on the text, ", "I think, ", "According to the text, ", "Based on the document you provided, "]
23
-
24
- def load_model(self, api_key):
25
- os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
26
- result = subprocess.run(["curl", "https://huggingface.co/api/whoami-v2", "-H", f"Authorization: Bearer {api_key}"], capture_output=True).stdout.decode()
27
- self.llm = HuggingFaceHub(repo_id="google/flan-t5-large", model_kwargs={"temperature":0, "max_length":512})
28
- self.chain = load_qa_chain(self.llm, chain_type="stuff")
29
- self.embeddings = HuggingFaceEmbeddings()
30
- if result == '{"error":"Invalid username or password."}':
31
- return "Invalid API token"
32
- else:
33
- return "HF Token successfully registered"
34
-
35
-
36
- def respond(self, text_input, question, chat_history):
37
- if text_input.startswith("http"):
38
- response = requests.get(text_input)
39
- text_var = response.text
40
- else:
41
- text_var = text_input
42
-
43
- time.sleep(0.5)
44
-
45
- documents = [Document(page_content=text_var, metadata=self.metadata)]
46
- text_splitter = CharacterTextSplitter(chunk_size=750, chunk_overlap=0)
47
- docs = text_splitter.split_documents(documents)
48
-
49
- if self.llm is None:
50
- raise ValueError("Model not loaded")
51
-
52
- db = FAISS.from_documents(docs, self.embeddings)
53
- query = question
54
-
55
- start_time = time.monotonic()
56
- try:
57
- docs = db.similarity_search(query)
58
- answer = self.chain.run(input_documents=docs, question=query, max_execution_time=5)
59
- except ValueError as e:
60
- answer = f"An error occurred: {str(e)}"
61
-
62
- bot_message = random.choice(self.init_mes) + answer + "."
63
- chat_history.append((question, bot_message))
64
- time.sleep(1)
65
- return "", chat_history
66
-
67
 
68
  document_chatbot = DocumentChatbot()
69
 
@@ -71,18 +10,19 @@ with gr.Blocks() as demo:
71
  title = """<p><h1 align="center" style="font-size: 36px;">Talk with your document</h1></p>"""
72
  gr.HTML(title)
73
  with gr.Row():
74
- text_input = gr.Textbox(label="Enter text or URL to text file")
75
- with gr.Column():
76
- api_key_input = gr.Textbox(label="Enter HF Token to load the model")
77
- api_key_input.submit(document_chatbot.load_model, inputs=api_key_input, outputs=api_key_input)
78
- chatbot = gr.Chatbot()
 
 
 
 
 
 
 
 
79
 
80
- q_input = gr.Textbox(label="Please write your question")
81
- clear = gr.Button("Clear")
82
- q_input.submit(document_chatbot.respond, [text_input, q_input, chatbot], [q_input, chatbot])
83
- clear.click(lambda: None, None, chatbot, queue=False)
84
-
85
-
86
-
87
 
88
- demo.launch(debug=True)
 
2
  import random
3
  import time
4
  import gradio as gr
5
+ from document_chatbot import DocumentChatbot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  document_chatbot = DocumentChatbot()
8
 
 
10
  title = """<p><h1 align="center" style="font-size: 36px;">Talk with your document</h1></p>"""
11
  gr.HTML(title)
12
  with gr.Row():
13
+ text_input = gr.Textbox(label="Enter text or URL to text file")
14
+ with gr.Column():
15
+ with gr.Row():
16
+ api_key_input = gr.Textbox(label="Enter HF Token to load the model")
17
+ api_key_input.submit(document_chatbot.load_token, inputs=api_key_input, outputs=api_key_input)
18
+ picked_model = gr.Dropdown(["google/flan-t5-large", "google/flan-t5-base","google/flan-t5-small"], label="Models", info="I'd recommend choosing the first one")
19
+ picked_model.change(document_chatbot.load_model, picked_model)
20
+ chatbot = gr.Chatbot()
21
+
22
+ q_input = gr.Textbox(label="Please write your question")
23
+ clear = gr.Button("Clear")
24
+ q_input.submit(document_chatbot.respond, [text_input, q_input, chatbot], [q_input, chatbot])
25
+ clear.click(lambda: None, None, chatbot, queue=False)
26
 
 
 
 
 
 
 
 
27
 
28
+ demo.launch(debug=True)
document_chatbot.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ import subprocess
5
+ import requests
6
+ from langchain.chains.question_answering import load_qa_chain
7
+ from langchain.text_splitter import CharacterTextSplitter
8
+ from langchain.embeddings import HuggingFaceEmbeddings
9
+ from langchain.docstore.document import Document
10
+ from langchain.document_loaders import TextLoader
11
+ from langchain.vectorstores import FAISS
12
+ from langchain import HuggingFaceHub
13
+
14
+
15
+ class DocumentChatbot:
16
+
17
+ def __init__(self):
18
+ self.llm = None
19
+ self.chain = None
20
+ self.embeddings = None
21
+ self.metadata = {"source": "internet"}
22
+ self.init_mes = ["According to the document, ", "Based on the text, ", "I think, ", "According to the text, ", "Based on the document you provided, "]
23
+
24
+ def load_token(self, api_key):
25
+ if api_key[:2] == "hf":
26
+ os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
27
+ result = subprocess.run(["curl", "https://huggingface.co/api/whoami-v2", "-H", f"Authorization: Bearer {api_key}"], capture_output=True).stdout.decode()
28
+ if result == '{"error":"Invalid username or password."}':
29
+ return "Invalid API token"
30
+ else:
31
+ return "HF Token successfully registered"
32
+
33
+ def load_model(self, model_name):
34
+ self.llm = HuggingFaceHub(repo_id=model_name, model_kwargs={"temperature":0, "max_length":512})
35
+ self.chain = load_qa_chain(self.llm, chain_type="stuff")
36
+ self.embeddings = HuggingFaceEmbeddings()
37
+ return f"Successfully loaded {model_name}"
38
+
39
+
40
+ def respond(self, text_input, question, chat_history):
41
+ if text_input.startswith("http"):
42
+ response = requests.get(text_input)
43
+ text_var = response.text
44
+ if text_var is None:
45
+ raise ValueError("No document is given")
46
+ else:
47
+ text_var = text_input
48
+
49
+
50
+
51
+ time.sleep(0.5)
52
+
53
+ documents = [Document(page_content=text_var, metadata=self.metadata)]
54
+ text_splitter = CharacterTextSplitter(chunk_size=750, chunk_overlap=0)
55
+ docs = text_splitter.split_documents(documents)
56
+
57
+ if self.llm is None:
58
+ raise ValueError("Model not loaded")
59
+
60
+
61
+
62
+ db = FAISS.from_documents(docs, self.embeddings)
63
+ query = question
64
+
65
+ try:
66
+ docs = db.similarity_search(query)
67
+ answer = self.chain.run(input_documents=docs, question=query)
68
+ bot_message = random.choice(self.init_mes) + answer + "."
69
+ except ValueError as e:
70
+ bot_message = f"An error occurred: {str(e)}"
71
+ chat_history.append((question, bot_message))
72
+
73
+ time.sleep(1)
74
+
75
+ return "", chat_history