Mohamed284 committed on
Commit
e2e6886
·
verified ·
1 Parent(s): 7ea9067

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -273
app.py CHANGED
@@ -1,276 +1,3 @@
1
- # Optimized RAG System with E5-Mistral Embeddings and Gemini 2.0 Flash Generation
2
- import json
3
- import logging
4
- import re
5
- import os
6
- import pickle
7
- from typing import List, Tuple, Optional
8
- import gradio as gr
9
- from openai import OpenAI
10
- from google import genai
11
- from functools import lru_cache
12
- from tenacity import retry, stop_after_attempt, wait_exponential
13
- from langchain_community.retrievers import BM25Retriever
14
- from langchain_community.vectorstores import FAISS
15
- from langchain_core.embeddings import Embeddings
16
- from langchain_core.documents import Document
17
- from collections import defaultdict
18
- import hashlib
19
- from tqdm import tqdm
20
-
21
- from dotenv import load_dotenv
22
- load_dotenv()
23
# --- Configuration ---
FAISS_INDEX_PATH = "faiss_index"  # base path for the persisted FAISS index (suffixed with data hash)
BM25_INDEX_PATH = "bm25_index.pkl"  # base path for the pickled BM25 retriever (suffixed with data hash)
CACHE_VERSION = "v1" # Increment when data format changes
embedding_model = "e5-mistral-7b-instruct" # OpenAI embedding model (served by the Academic Cloud gateway below)
generation_model = "gemini-2.0-flash" # Gemini generation model
data_file_name = "AskNatureNet_data_enhanced.json"  # source corpus: JSON list of strategy records
API_CONFIG = {
    "gemini_api_key": os.getenv("GEMINI_API_KEY") # Gemini API key for generation
}

CHUNK_SIZE = 800  # characters per strategy chunk
OVERLAP = 200  # characters shared between consecutive chunks
EMBEDDING_BATCH_SIZE = 32 # Batch size for embedding API calls

# Initialize clients
# NOTE(review): the "OpenAI" client is pointed at an OpenAI-compatible
# Academic Cloud endpoint that serves the E5-Mistral embedding model.
OPENAI_API_CONFIG = {
    "api_key": os.getenv("OPENAI_API_KEY"),
    "base_url": "https://chat-ai.academiccloud.de/v1"
}
client = OpenAI(**OPENAI_API_CONFIG)
gemini_client = genai.Client(api_key=API_CONFIG["gemini_api_key"]) # Gemini client for generation
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
47
-
48
# --- Helper Functions ---
def get_data_hash(file_path: str) -> str:
    """Return the MD5 hex digest of the file at *file_path*.

    Used purely as a cache-invalidation fingerprint (not for security).
    Hashes the file incrementally in fixed-size chunks so large data
    files do not have to fit in memory, unlike the previous
    read-everything-then-hash approach.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as f:
        # 64 KiB chunks keep memory bounded regardless of file size.
        for chunk in iter(lambda: f.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()
53
-
54
# --- Custom Embedding Handler with Progress Tracking ---
class MistralEmbeddings(Embeddings):
    """E5-Mistral-7B embedding adapter with batching and progress tracking."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* in batches of EMBEDDING_BATCH_SIZE, one vector per text.

        Raises on API failure instead of the previous behaviour of
        returning an empty vector for every text: empty embeddings would
        have been written straight into the FAISS index and silently
        corrupted retrieval with no visible error.
        """
        embeddings: List[List[float]] = []
        # Process in batches with progress tracking
        for start in tqdm(range(0, len(texts), EMBEDDING_BATCH_SIZE), desc="Embedding Progress"):
            batch = texts[start:start + EMBEDDING_BATCH_SIZE]
            try:
                response = client.embeddings.create(
                    input=batch,
                    model=embedding_model,
                    encoding_format="float"
                )
            except Exception as e:
                # Report which batch failed, then propagate so callers
                # (index build / query path) can handle or retry explicitly.
                logger.error(f"Embedding Error at batch offset {start}: {str(e)}")
                raise
            embeddings.extend([item.embedding for item in response.data])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self.embed_documents([text])[0]
76
-
77
# --- Data Processing with Cache Validation ---
def load_and_chunk_data(file_path: str) -> List[Document]:
    """Enhanced chunking with metadata preservation.

    Loads the JSON corpus at *file_path*, splits each record's
    'Strategy' text into overlapping CHUNK_SIZE-character windows
    (OVERLAP characters shared between neighbours), and prefixes every
    chunk with the record's key metadata so each Document is
    self-describing. Results are pickled to a cache file keyed by
    CACHE_VERSION and the data file's content hash.
    """
    # Fingerprint the raw data so the cache invalidates whenever either
    # the data file or CACHE_VERSION changes.
    current_hash = get_data_hash(file_path)
    cache_file = f"documents_{CACHE_VERSION}_{current_hash}.pkl"

    if os.path.exists(cache_file):
        logger.info("Loading cached documents")
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    documents = []
    for item in tqdm(data, desc="Chunking Progress"):
        # Shared metadata header repeated in every chunk of this record.
        # NOTE(review): 'Source', 'Application', 'technical_concepts',
        # 'biological_mechanisms', 'Strategy', 'sustainability_impacts'
        # and 'Hyperlink' are assumed present in every record — a
        # missing key raises KeyError; only Function1/Function2 are optional.
        base_content = f"""Source: {item['Source']}
Application: {item['Application']}
Functions: {', '.join(filter(None, [item.get('Function1'), item.get('Function2')]))}
Technical Concepts: {', '.join(item['technical_concepts'])}
Biological Mechanisms: {', '.join(item['biological_mechanisms'])}"""

        strategy = item['Strategy']
        # Sliding window: step = CHUNK_SIZE - OVERLAP, so consecutive
        # chunks overlap by OVERLAP characters.
        for i in range(0, len(strategy), CHUNK_SIZE - OVERLAP):
            chunk = strategy[i:i + CHUNK_SIZE]
            documents.append(Document(
                page_content=f"{base_content}\nStrategy Excerpt:\n{chunk}",
                metadata={
                    "source": item["Source"],
                    "application": item["Application"],
                    "technical_concepts": item["technical_concepts"],
                    "sustainability_impacts": item["sustainability_impacts"],
                    "hyperlink": item["Hyperlink"],
                    # Uses the GLOBAL running document count, not the
                    # per-record chunk position.
                    "chunk_id": f"{item['Source']}-{len(documents)+1}"
                }
            ))

    with open(cache_file, "wb") as f:
        pickle.dump(documents, f)
    return documents
117
-
118
# --- Optimized Retrieval System ---
class EnhancedRetriever:
    """Hybrid retriever with persistent caching.

    Combines BM25 keyword search, FAISS vector search and HyDE query
    expansion, fusing the three result lists with reciprocal-rank
    fusion. Both indexes are persisted on disk, keyed by the data
    file's content hash, so restarts skip the expensive rebuild.
    """

    def __init__(self, documents: List[Document]):
        self.documents = documents
        self.bm25 = self._init_bm25()
        self.vector_store = self._init_faiss()
        self.vector_retriever = self.vector_store.as_retriever(search_kwargs={"k": 3})
        # Per-instance bounded caches. The previous @lru_cache(maxsize=500)
        # on instance methods keyed on `self` and kept the whole retriever
        # (documents + both indexes) alive for the cache's lifetime (B019).
        self._cache_limit = 500
        self._retrieve_cache: dict = {}
        self._hyde_cache: dict = {}

    def _cache_put(self, cache: dict, key: str, value: str) -> None:
        """Insert into *cache*, evicting the oldest entry at the size limit."""
        if len(cache) >= self._cache_limit:
            cache.pop(next(iter(cache)))  # dicts preserve insertion order
        cache[key] = value

    def _init_bm25(self) -> BM25Retriever:
        """Load the pickled BM25 index if cached on disk, else build and persist it."""
        cache_key = f"{BM25_INDEX_PATH}_{get_data_hash(data_file_name)}"
        if os.path.exists(cache_key):
            logger.info("Loading cached BM25 index")
            with open(cache_key, "rb") as f:
                return pickle.load(f)

        logger.info("Building new BM25 index")
        retriever = BM25Retriever.from_documents(self.documents)
        retriever.k = 5  # top-5 keyword matches per query
        with open(cache_key, "wb") as f:
            pickle.dump(retriever, f)
        return retriever

    def _init_faiss(self) -> FAISS:
        """Load the FAISS index from disk if present, else embed all documents and persist."""
        cache_key = f"{FAISS_INDEX_PATH}_{get_data_hash(data_file_name)}"
        if os.path.exists(cache_key):
            logger.info("Loading cached FAISS index")
            # Deserializing our own saved index; never point this at an
            # untrusted file (pickle-based deserialization).
            return FAISS.load_local(
                cache_key,
                MistralEmbeddings(),
                allow_dangerous_deserialization=True
            )

        logger.info("Building new FAISS index")
        vector_store = FAISS.from_documents(self.documents, MistralEmbeddings())
        vector_store.save_local(cache_key)
        return vector_store

    def retrieve(self, query: str) -> str:
        """Return formatted context for *query*, or "" when retrieval fails."""
        if query in self._retrieve_cache:
            return self._retrieve_cache[query]
        try:
            processed_query = self._preprocess_query(query)
            expanded_query = self._hyde_expansion(processed_query)

            bm25_results = self.bm25.invoke(processed_query)
            vector_results = self.vector_retriever.invoke(processed_query)
            expanded_results = self.bm25.invoke(expanded_query)

            fused_results = self._fuse_results([bm25_results, vector_results, expanded_results])
            context = self._format_context(fused_results[:5])
        except Exception as e:
            logger.error(f"Retrieval Error: {str(e)}")
            # Failures are not cached (unlike the old lru_cache, which
            # pinned the "" result), so a later retry can succeed.
            return ""
        self._cache_put(self._retrieve_cache, query, context)
        return context

    def _preprocess_query(self, query: str) -> str:
        """Normalize a raw user query for lexical matching."""
        return query.lower().strip()

    def _hyde_expansion(self, query: str) -> str:
        """HyDE expansion: ask Gemini for a hypothetical technical draft.

        The draft's domain vocabulary improves BM25 recall; on API
        failure the original query is returned unchanged.
        """
        if query in self._hyde_cache:
            return self._hyde_cache[query]
        try:
            response = gemini_client.models.generate_content(  # Use Gemini client for HyDE
                model=generation_model,
                contents=f"Generate a technical draft about biomimicry for: {query}\nInclude domain-specific terms."
            )
            expansion = response.text
        except Exception as e:
            logger.error(f"HyDE Error: {str(e)}")
            expansion = query
        self._cache_put(self._hyde_cache, query, expansion)
        return expansion

    def _fuse_results(self, result_sets: List[List[Document]]) -> List[Document]:
        """Reciprocal-rank fusion (k=60) across result lists, deduplicated by chunk_id."""
        fused_scores = defaultdict(float)
        for docs in result_sets:
            for rank, doc in enumerate(docs, 1):
                fused_scores[doc.metadata["chunk_id"]] += 1 / (rank + 60)

        # Sort all candidates by fused score, keeping only the first
        # occurrence of each chunk_id (seen.add returns None -> falsy).
        seen = set()
        return [
            doc for doc in sorted(
                (doc for docs in result_sets for doc in docs),
                key=lambda x: fused_scores[x.metadata["chunk_id"]],
                reverse=True
            ) if not (doc.metadata["chunk_id"] in seen or seen.add(doc.metadata["chunk_id"]))
        ]

    def _format_context(self, docs: List[Document]) -> str:
        """Render documents as a markdown context block separated by horizontal rules."""
        context = []
        for doc in docs:
            context_str = f"""**Source**: [{doc.metadata['source']}]({doc.metadata['hyperlink']})
**Application**: {doc.metadata['application']}
**Key Concepts**: {', '.join(doc.metadata['technical_concepts'])}
**Strategy Excerpt**:\n{doc.page_content.split('Strategy Excerpt:')[-1].strip()}"""
            context.append(context_str)
        return "\n\n---\n\n".join(context)
211
-
212
# --- Generation System ---
# Prompt template sent to Gemini; {context} is replaced with the
# retriever's formatted document excerpts before every generation call.
SYSTEM_PROMPT = """**Biomimicry Expert Guidelines**
1. Base answers strictly on context
2. **Bold** technical terms
3. Include reference links at the end of the response

Context: {context}"""
219
-
220
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=20))
def _generate_with_retry(prompt: str) -> str:
    """Single Gemini call; exceptions propagate so @retry can re-attempt.

    The previous implementation caught Exception INSIDE the retried
    function, so tenacity never observed a failure and the three retry
    attempts with exponential backoff were dead code.
    """
    response = gemini_client.models.generate_content(  # Use Gemini client for generation
        model=generation_model,
        contents=prompt
    )
    logger.info(f"Raw Response: {response.text}")  # Log raw response
    return response.text


def get_ai_response(query: str, context: str) -> str:
    """Answer *query* grounded in *context* via Gemini.

    Same contract as before: never raises to the caller; returns a
    user-facing fallback message if all retry attempts fail.
    """
    prompt = f"{SYSTEM_PROMPT.format(context=context)}\nQuestion: {query}\nProvide a detailed technical answer:"
    try:
        return _postprocess_response(_generate_with_retry(prompt))
    except Exception as e:
        logger.error(f"Generation Error: {str(e)}")
        return "I'm unable to generate a response right now. Please try again later."
232
-
233
- def _postprocess_response(response: str) -> str:
234
- response = re.sub(r"\[(.*?)\]", r"[\1](#)", response)
235
- response = re.sub(r"\*\*([\w-]+)\*\*", r"**\1**", response)
236
- return response
237
-
238
# --- Optimized Pipeline ---
# Build (or load from cache) the chunked corpus and hybrid retriever
# once at import time so every Gradio request reuses the same indexes.
documents = load_and_chunk_data(data_file_name)
retriever = EnhancedRetriever(documents)
241
-
242
def generate_response(question: str) -> str:
    """End-to-end pipeline: retrieve context for *question*, then answer.

    Returns a user-facing message when no context is found or any stage
    raises; never propagates an exception to the UI layer.
    """
    try:
        context = retriever.retrieve(question)
        if not context:
            return "No relevant information found."
        return get_ai_response(question, context)
    except Exception as e:
        logger.error(f"Pipeline Error: {str(e)}")
        return "An error occurred processing your request."
249
-
250
# --- Gradio Interface ---
def chat_interface(question: str, history: List[Tuple[str, str]]):
    """Gradio submit callback: answer *question* and extend the chat history.

    Returns ("", new_history): the empty string clears the input box and
    the extended history refreshes the chatbot widget.
    """
    answer = generate_response(question)
    updated_history = history + [(question, answer)]
    return "", updated_history
254
-
255
# Declarative UI layout: one chat pane, a question box with a clear
# button, and a footer crediting the data source.
with gr.Blocks(title="AskNature BioRAG Expert", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌿 AskNature RAG-based Chatbot ")
    with gr.Row():
        chatbot = gr.Chatbot(label="Dialogue History", height=500)
    with gr.Row():
        question = gr.Textbox(placeholder="Ask about biomimicry (e.g. 'How does Werewool use coral proteins to make fibers?')",
                              label="Inquiry", scale=4)
        clear_btn = gr.Button("Clear History", variant="secondary")

    gr.Markdown("""
    <div style="text-align: center; color: #4a7c59;">
    <small>Powered by AskNature's Database |
    Explore nature's blueprints at <a href="https://asknature.org">asknature.org</a></small>
    </div>""")
    # Submit sends (question, history) and receives ("", updated_history):
    # clears the input box and refreshes the chat pane.
    question.submit(chat_interface, [question, chatbot], [question, chatbot])
    # Reset the chat history to an empty list.
    clear_btn.click(lambda: [], None, chatbot)
271
-
272
- if __name__ == "__main__":
273
- =======
274
  # Optimized RAG System with E5-Mistral Embeddings and Gemini Flash Generation
275
 
276
  import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Optimized RAG System with E5-Mistral Embeddings and Gemini Flash Generation
2
 
3
  import json