ksimdeep committed · Commit 7e40afd · verified · 1 parent: 5dd5eb7

Create app.py

Files changed (1): app.py +251 -0
app.py ADDED
@@ -0,0 +1,251 @@
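"""Streamlit RAG chat app: upload TXT/CSV documents, retrieve relevant chunks
with a hybrid BM25 + dense-embedding retriever, and answer questions with a
4-bit-quantized Mistral-7B-Instruct model.

Assumed dependency set, inferred from the imports rather than pinned anywhere
in this commit: streamlit, torch, transformers, accelerate, bitsandbytes,
langchain, langchain-community, rank_bm25, faiss-cpu (or faiss-gpu),
sentence-transformers, unstructured.
"""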
import torch  # Add missing import
import streamlit as st
import os
import tempfile
from langchain_community.document_loaders import (
    TextLoader,
    CSVLoader,
    UnstructuredFileLoader
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS  # needed for the dense half of the hybrid retriever
from langchain.retrievers import EnsembleRetriever
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline
)

# Configuration
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
EMBEDDING_MODEL = "thenlper/gte-large"
CHUNK_SIZE = 1024
CHUNK_OVERLAP = 128
MAX_NEW_TOKENS = 2048

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []

@st.cache_resource
def initialize_model():
    # 4-bit NF4 quantization with double quantization keeps the 7B model on a single GPU
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )

    # Load config first to modify RoPE params
    config = AutoConfig.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )

    # Fix RoPE scaling configuration; the truthiness check also guards against
    # rope_scaling being None, where the original hasattr() check would crash on .get()
    if getattr(config, "rope_scaling", None):
        config.rope_scaling = {
            "type": config.rope_scaling.get("rope_type", "linear"),
            "factor": config.rope_scaling.get("factor", 8.0)
        }

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        config=config,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True
    )

    # The model is already dispatched by device_map="auto" above, so the
    # pipeline must not be given a device/device_map again
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=0.1
    )

def process_uploaded_files(uploaded_files):
    documents = []
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in uploaded_files:
            # Write the upload to disk so the path-based loaders can read it
            temp_path = os.path.join(temp_dir, file.name)
            with open(temp_path, "wb") as f:
                f.write(file.getbuffer())

            try:
                if file.name.endswith(".txt"):
                    loader = TextLoader(temp_path)
                elif file.name.endswith(".csv"):
                    loader = CSVLoader(temp_path)
                else:
                    # Fallback for any other type (requires the `unstructured` package);
                    # only reachable if the uploader's `type` list is widened
                    loader = UnstructuredFileLoader(temp_path)
                documents.extend(loader.load())
            except Exception as e:
                st.error(f"Error loading {file.name}: {str(e)}")

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len
    )
    return text_splitter.split_documents(documents)

def create_retriever(documents):
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={'device': 'cuda'},
        encode_kwargs={'normalize_embeddings': True}
    )

    top_k = st.session_state.get("top_k", 5)

    # Sparse keyword retriever
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = top_k

    # Dense embedding retriever. The original built `embeddings` but never used
    # them and ensembled a single retriever with weight 0.5; a FAISS vector
    # store (assumes faiss-cpu/faiss-gpu is installed) completes the intended
    # hybrid setup
    faiss_retriever = FAISS.from_documents(documents, embeddings).as_retriever(
        search_kwargs={"k": top_k}
    )

    return EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.5, 0.5]
    )

def generate_response(query, retriever, generator):
    # invoke() replaces the deprecated get_relevant_documents() in recent LangChain
    docs = retriever.invoke(query)
    context = "\n\n".join(
        f"[Doc{i+1}] {doc.page_content}\nSource: {doc.metadata.get('source', 'unknown')}"
        for i, doc in enumerate(docs)
    )

    prompt = f"""<s>[INST] You are a precision-focused research assistant tasked with answering queries based solely on the provided context.

**Context:**
{context}

**Query:**
{query}

**Response Instructions:**
- Write a detailed, coherent, and insightful article that fully addresses the query based on the provided context.
- Adhere to the following principles:
  1. **Define the Core Subject**: Introduce and build the discussion logically around the main topic.
  2. **Establish Connections**: Highlight relationships between ideas and concepts with reasoning and examples.
  3. **Elaborate on Key Points**: Provide in-depth explanations and emphasize the significance of concepts.
  4. **Maintain Objectivity**: Use only the context provided, avoiding speculation or external knowledge.
  5. **Ensure Structure and Clarity**: Present information sequentially for a smooth narrative flow.
  6. **Engage with Content**: Explore implicit meanings, resolve doubts, and address counterpoints logically.
  7. **Provide Examples and Insights**: Use examples to clarify abstract ideas and offer actionable steps if applicable.
  8. **Logical Depth**: Draw inferences, explain purposes, and refute opposing ideas when necessary.
- Cite sources explicitly as [Doc1], [Doc2], etc.
- If uncertain, state: "I cannot determine from the provided context."

Craft the response as a seamless, thorough, and authoritative explanation that naturally integrates all aspects of the query. [/INST]"""

    response = generator(
        prompt,
        pad_token_id=generator.tokenizer.eos_token_id,
        do_sample=True
    )[0]['generated_text']

    # The pipeline echoes the prompt before the completion, so keep only the
    # text after the final [/INST] marker
    return response.split("[/INST]")[-1].strip(), docs

# Streamlit UI
st.title("📚 Document-Based QA Assistant")
st.markdown("Upload your documents and ask questions!")

# Sidebar controls
with st.sidebar:
    st.header("Configuration")
    uploaded_files = st.file_uploader(
        "Upload documents (TXT or CSV)",
        type=["txt", "csv"],
        accept_multiple_files=True
    )
    st.session_state.top_k = st.slider("Number of documents to retrieve", 3, 10, 5)
    st.markdown("---")
    st.markdown("Powered by Mistral-7B and LangChain")

# Main chat interface
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        if "sources" in message:
            with st.expander("View Sources"):
                for i, doc in enumerate(message["sources"]):
                    st.markdown(f"**Doc{i+1}** ({doc.metadata.get('source', 'unknown')})")
                    st.info(doc.page_content)

# Process documents
if uploaded_files and "retriever" not in st.session_state:
    with st.spinner("Processing documents..."):
        documents = process_uploaded_files(uploaded_files)
        st.session_state.retriever = create_retriever(documents)
        st.session_state.generator = initialize_model()

if prompt := st.chat_input("Ask a question about your documents"):
    # Add user message
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate response
    if "retriever" not in st.session_state:
        st.error("Please upload documents first!")
        st.stop()

    with st.spinner("Analyzing documents..."):
        try:
            response, sources = generate_response(
                prompt,
                st.session_state.retriever,
                st.session_state.generator
            )

            # Add assistant response
            st.session_state.messages.append({
                "role": "assistant",
                "content": response,
                "sources": sources
            })

            # Display response
            with st.chat_message("assistant"):
                st.markdown(response)
                with st.expander("View Document Sources"):
                    for i, doc in enumerate(sources):
                        st.markdown(f"**Doc{i+1}** ({doc.metadata.get('source', 'unknown')})")
                        st.info(doc.page_content)

        except Exception as e:
            st.error(f"Error generating response: {str(e)}")
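
# Launch with the standard Streamlit entry point:
#   streamlit run app.py
# Rough sizing assumption: the 4-bit NF4 Mistral-7B weights take on the order
# of 4-5 GB of VRAM, plus the gte-large embedder, which create_retriever also
# places on CUDA.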