lubani committed on
Commit
d54727c
·
1 Parent(s): 213d7a3
Files changed (2)
  1. app.py +202 -0
  2. custom_llm.py +65 -0
app.py ADDED
@@ -0,0 +1,202 @@
import os
import time
from llama_index.core import SimpleDirectoryReader
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import VectorStoreIndex
from custom_llm import CustomLLM
import gradio as gr
# import shutil
import tempfile

# default model
repo_id = "mistralai/Mistral-7B-Instruct-v0.1"
model_type = 'text-generation'

API_TOKEN = os.getenv('HF_INFER_API')
temp_dir = tempfile.TemporaryDirectory()


embedding_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
llm = CustomLLM(repo_id=repo_id, model_type=model_type, api_token=API_TOKEN)


def add_text(history, text):
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)


def hasFile(history):
    # True if any user turn so far was a PDF upload
    for user_prompt, bot_response in history:
        if '.pdf' in user_prompt.lower():
            return True
    return False


def modelChanged(history, drop):
    # record the model switch in the chat so bot() can announce it
    history = history + [(f'===> {drop}', None)]
    return history, drop


def getEngine(llm):
    loader = SimpleDirectoryReader(
        input_dir=temp_dir.name,
        recursive=True,
        required_exts=[".pdf", ".PDF"],
    )

    # Load files as documents
    documents = loader.load_data()

    # create an in-memory index
    index = VectorStoreIndex.from_documents(
        documents,
        embed_model=embedding_model,
    )

    # create the query engine
    query_engine = index.as_query_engine(llm=llm)
    return query_engine


def copy_pdf(source_path, destination_path):
    # Open the source PDF file in binary read mode
    with open(source_path, "rb") as source_file:
        # Read the entire content of the source file
        data = source_file.read()

    # Open the destination file in binary write mode
    with open(destination_path, "wb") as destination_file:
        # Write the copied data to the destination file
        destination_file.write(data)

    print(f"PDF copied successfully from {source_path} to {destination_path}")


def add_file(history, file):
    file_path = os.path.join(temp_dir.name, os.path.basename(file))
    # shutil.copyfile(file.name, file_path)  # <--- asynchronous
    copy_pdf(file.name, file_path)

    history = history + [(os.path.basename(file), None)]
    return history


def format_prompt(message, history, model):
    if model is None or 'mistral' in model.lower():
        prompt = "<s>"
        for user_prompt, bot_response in history:
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "
        prompt += f"[INST] {message} [/INST]"

    elif 'google' in model.lower():
        prompt = "<bos>"
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user {user_prompt} <end_of_turn><start_of_turn>model {bot_response}"
        prompt += f"<start_of_turn>user {message} <end_of_turn><start_of_turn>model"

    else:
        prompt = ""

    return prompt


def bot(history, model=None):
    print("===> model: ", model)
    local_llm = llm
    if model:
        local_llm = CustomLLM(repo_id=model, model_type=model_type, api_token=API_TOKEN)

    if len(history) > 0 and '.pdf' in history[-1][0]:
        response = "You uploaded a PDF file. You can ask questions about the file."

    elif len(history) > 0 and '===>' in history[-1][0]:
        new_model = history[-1][0].replace("===>", "")
        response = f"You have changed the model to {new_model}"

    else:
        prompt = history[-1][0]

        if hasFile(history):
            query_engine = getEngine(local_llm)
            response = query_engine.query(prompt)
            print("Response from file")
        else:
            # exclude the pending (message, None) pair from the history,
            # otherwise the template renders a literal "None" response
            response = local_llm.predict(format_prompt(prompt, history[:-1], model))
            print("Response from Model")

    # stream the answer back character by character
    history[-1][1] = ""
    for character in str(response):
        history[-1][1] += character
        # time.sleep(0.05)
        yield history


with gr.Blocks() as demo:

    gr.Markdown(
        """
        <div style="display: grid; justify-content: center;">
        <h1>Basic RAG with Huggingface Inference API</h1>
        <h4>For best performance, start with small PDF files (fewer than 20 pages).</h4>
        </div>
        """
    )

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        # avatar_images=(None, os.path.join(os.path.dirname(__file__), "avatar.png")),
    )

    with gr.Row():
        drop = gr.Dropdown(
            [
                ("Mixtral-8x7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"),
                ("Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.2"),
                ("gemma-7b-it", "google/gemma-7b-it"),
                ("gemma-2b-it", "google/gemma-2b-it"),
            ],
            value="mistralai/Mixtral-8x7B-Instruct-v0.1",
            label="Model",
            info="",
        )

    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Type your question and press enter",
            container=False,
        )
        btn = gr.UploadButton("📁", file_types=[".pdf"])

    drop.change(modelChanged, [chatbot, drop], [chatbot, drop], queue=False).then(
        bot, [chatbot, drop], chatbot
    )

    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, [chatbot, drop], chatbot, api_name="bot_response"
    )
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
        bot, [chatbot, drop], chatbot
    )


demo.queue()
demo.launch()
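As a quick sanity check (not part of the committed files), here is what format_prompt builds for the Mistral-style template, assuming a single completed turn in the history:

history = [("Hi", "Hello!")]
print(format_prompt("What is RAG?", history, "mistralai/Mistral-7B-Instruct-v0.1"))
# prints: <s>[INST] Hi [/INST] Hello!</s> [INST] What is RAG? [/INST]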
custom_llm.py ADDED
@@ -0,0 +1,65 @@
from langchain_core.language_models.llms import LLM
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
import requests
from typing import Any, List, Mapping, Optional, Literal


class CustomLLM(LLM):
    # Properties (optional generation parameters default to None so they can
    # be omitted from the request)
    repo_id: str
    api_token: str
    model_type: Literal["text2text-generation", "text-generation"]
    max_new_tokens: Optional[int] = None
    temperature: float = 0.001
    timeout: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    stop: List[str] = []

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        headers = {"Authorization": f"Bearer {self.api_token}"}
        API_URL = f"https://api-inference.huggingface.co/models/{self.repo_id}"

        parameters_dict = {
            'max_new_tokens': self.max_new_tokens,
            'temperature': self.temperature,
            'timeout': self.timeout,
            'top_p': self.top_p,
            'top_k': self.top_k,
            'repetition_penalty': self.repetition_penalty,
            # prefer stop sequences passed at call time over the field default
            'stop': stop if stop is not None else self.stop,
        }

        if self.model_type == 'text-generation':
            # return only the completion, not the prompt echoed back
            parameters_dict["return_full_text"] = False

        payload = {"inputs": prompt, "parameters": parameters_dict, "options": {"wait_for_model": True}}
        response = requests.post(API_URL, headers=headers, json=payload)
        response.raise_for_status()
        return response.json()[0]['generated_text']

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            'repo_id': self.repo_id,
            'model_type': self.model_type,
            'stop_sequences': self.stop,
            'max_new_tokens': self.max_new_tokens,
            'temperature': self.temperature,
            'timeout': self.timeout,
            'top_p': self.top_p,
            'top_k': self.top_k,
            'repetition_penalty': self.repetition_penalty,
        }
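A minimal usage sketch of CustomLLM on its own (it assumes, as app.py does, that the HF_INFER_API environment variable holds a valid Hugging Face token):

import os
from custom_llm import CustomLLM

llm = CustomLLM(
    repo_id="mistralai/Mistral-7B-Instruct-v0.1",
    model_type="text-generation",
    api_token=os.getenv("HF_INFER_API"),
)
print(llm.predict("<s>[INST] What is retrieval-augmented generation? [/INST]"))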