amira01 commited on
Commit
91e0ac4
·
verified ·
1 Parent(s): 19c02f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -50
app.py CHANGED
@@ -1,64 +1,202 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
 
26
- messages.append({"role": "user", "content": message})
27
 
28
- response = ""
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
 
 
 
 
1
  import gradio as gr
2
+ from langchain_community.document_loaders import PyPDFLoader
3
+ from langchain_community.embeddings import HuggingFaceEmbeddings
4
+ from langchain_community.vectorstores import FAISS
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.chains import RetrievalQA
7
+ from langchain.prompts import PromptTemplate
8
+ from langchain_community.chat_models import ChatOpenAI
9
+ import os
10
+ from tempfile import NamedTemporaryFile
11
 
12
+ # Load API Key
13
+ def load_api_key():
14
+ if "OPENROUTER_API_KEY" in os.environ:
15
+ return os.getenv("OPENROUTER_API_KEY")
16
+ raise ValueError("API key not found in environment variables")
17
 
18
+ # Process PDF files
19
+ def process_pdfs(files):
20
+ all_chunks = []
21
+ for file_info in files: # file_info is a Gradio File object
22
+ with NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
23
+ # Copy file content to temporary file
24
+ with open(file_info.name, "rb") as f:
25
+ tmp_file.write(f.read())
26
+ tmp_file_path = tmp_file.name
27
+
28
+ try:
29
+ loader = PyPDFLoader(tmp_file_path)
30
+ pages = loader.load()
31
+
32
+ text_splitter = RecursiveCharacterTextSplitter(
33
+ chunk_size=1000,
34
+ chunk_overlap=200,
35
+ length_function=len
36
+ )
37
+ chunks = text_splitter.split_documents(pages)
38
+ all_chunks.extend(chunks)
39
+ finally:
40
+ os.unlink(tmp_file_path)
41
+
42
+ if not all_chunks:
43
+ raise ValueError("No content was loaded from the files")
44
+
45
+ embeddings = HuggingFaceEmbeddings(
46
+ model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
47
+ )
48
+ vectorstore = FAISS.from_documents(all_chunks, embeddings)
49
+ return vectorstore.as_retriever(search_kwargs={"k": 3})
50
 
51
+ # Initialize language model
52
+ def load_model():
53
+ return ChatOpenAI(
54
+ base_url="https://openrouter.ai/api/v1",
55
+ api_key="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",
56
+ model="mistralai/mistral-7b-instruct",
57
+ temperature=0.3
58
+ )
 
59
 
60
+ # Prompt template
61
+ template = """
62
+ You are an intelligent assistant specialized in document analysis.
63
+ Use the following information from PDF files to answer the question:
 
64
 
65
+ {context}
66
 
67
+ Question: {question}
68
+ Answer (in detail and in clear language):
69
+ """
70
 
71
+ prompt = PromptTemplate(
72
+ input_variables=["context", "question"],
73
+ template=template
74
+ )
 
 
 
 
75
 
76
+ # Application state variables
77
+ qa_chain = None
78
+ chat_history = []
79
 
80
+ # Function to process messages and respond
81
+ def respond(message, chat_history):
82
+ global qa_chain
83
+
84
+ if qa_chain is None:
85
+ return chat_history + [(message, "Please upload PDF files first")]
86
+
87
+ try:
88
+ result = qa_chain({"query": message})
89
+ response = result["result"]
90
+ return chat_history + [(message, response)]
91
+ except Exception as e:
92
+ return chat_history + [(message, f"An error occurred: {str(e)}")]
93
 
94
+ # Handle file uploads
95
+ def handle_upload(files):
96
+ global qa_chain
97
+
98
+ try:
99
+ file_contents = [file.read() for file in files]
100
+ retriever = process_pdfs(file_contents)
101
+ llm = load_model()
102
+
103
+ qa_chain = RetrievalQA.from_chain_type(
104
+ llm=llm,
105
+ retriever=retriever,
106
+ chain_type="stuff",
107
+ chain_type_kwargs={"prompt": prompt},
108
+ return_source_documents=False
109
+ )
110
+ return "Files uploaded successfully! You can now ask questions."
111
+ except Exception as e:
112
+ return f"Error uploading files: {str(e)}"
113
 
114
+ # UI
115
+ with gr.Blocks(title="Smart Document Assistant", theme=gr.themes.Default()) as demo:
116
+ gr.Markdown("# 📄 Smart Document Assistant")
117
+ gr.Markdown("Upload PDF files then start chatting")
118
+
119
+ # Chat section
120
+ chatbot = gr.Chatbot(height=500)
121
+
122
+ # Input section
123
+ with gr.Row():
124
+ msg = gr.Textbox(
125
+ placeholder="Type your question here...",
126
+ show_label=False,
127
+ scale=4
128
+ )
129
+ submit_btn = gr.Button("Send", scale=1)
130
+
131
+ # File section
132
+ with gr.Row():
133
+ file_upload = gr.Files(
134
+ label="Upload PDF files",
135
+ file_types=[".pdf"],
136
+ file_count="multiple"
137
+ )
138
+ upload_status = gr.Textbox(label="Upload Status", interactive=False)
139
+
140
+ clear_btn = gr.Button("Clear Chat")
141
+
142
+ # Event handling
143
+ def handle_upload(files):
144
+ global qa_chain
145
+ try:
146
+ retriever = process_pdfs(files)
147
+ llm = load_model()
148
+
149
+ qa_chain = RetrievalQA.from_chain_type(
150
+ llm=llm,
151
+ retriever=retriever,
152
+ chain_type="stuff",
153
+ chain_type_kwargs={
154
+ "prompt": PromptTemplate(
155
+ template=template,
156
+ input_variables=["context", "question"]
157
+ )
158
+ },
159
+ return_source_documents=False
160
+ )
161
+ return "Files uploaded successfully!"
162
+ except Exception as e:
163
+ return f"Error uploading files: {str(e)}"
164
+
165
+ file_upload.change(
166
+ handle_upload,
167
+ inputs=file_upload,
168
+ outputs=upload_status
169
+ )
170
+
171
+ submit_btn.click(
172
+ respond,
173
+ inputs=[msg, chatbot],
174
+ outputs=[chatbot]
175
+ ).then(
176
+ lambda: "",
177
+ None,
178
+ [msg]
179
+ )
180
+
181
+ msg.submit(
182
+ respond,
183
+ inputs=[msg, chatbot],
184
+ outputs=[chatbot]
185
+ ).then(
186
+ lambda: "",
187
+ None,
188
+ [msg]
189
+ )
190
+
191
+ clear_btn.click(
192
+ lambda: [],
193
+ None,
194
+ [chatbot]
195
+ )
196
 
197
  if __name__ == "__main__":
198
+ demo.launch(
199
+ server_name="0.0.0.0",
200
+ server_port=7860,
201
+ share=True # To get a public link
202
+ )