Mohamed284 committed on
Commit
c1f8517
·
verified ·
1 Parent(s): 5202779

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -25
app.py CHANGED
@@ -2,7 +2,9 @@
2
  import json
3
  import logging
4
  import re
5
- from typing import List, Tuple
 
 
6
  import gradio as gr
7
  from openai import OpenAI
8
  from functools import lru_cache
@@ -12,51 +14,74 @@ from langchain_community.vectorstores import FAISS
12
  from langchain_core.embeddings import Embeddings
13
  from langchain_core.documents import Document
14
  from collections import defaultdict
15
- import os
16
-
17
-
18
 
 
 
 
 
19
  embedding_model = "e5-mistral-7b-instruct"
20
  generation_model = "meta-llama-3-70b-instruct"
21
- # --- Configuration ---
22
  API_CONFIG = {
23
  "api_key": os.getenv("API_KEY"),
24
  "base_url": "https://chat-ai.academiccloud.de/v1"
25
  }
26
  CHUNK_SIZE = 800
27
  OVERLAP = 200
 
28
 
29
  # Initialize clients
30
  client = OpenAI(**API_CONFIG)
31
  logging.basicConfig(level=logging.INFO)
32
  logger = logging.getLogger(__name__)
33
 
34
- # --- Custom Embedding Handler ---
 
 
 
 
 
 
35
  class MistralEmbeddings(Embeddings):
36
- """E5-Mistral-7B embedding adapter with error handling"""
37
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
 
38
  try:
39
- response = client.embeddings.create(
40
- input=texts,
41
- model=embedding_model,
42
- encoding_format="float"
43
- )
44
- return [e.embedding for e in response.data]
 
 
 
 
45
  except Exception as e:
46
  logger.error(f"Embedding Error: {str(e)}")
47
- return [[] for _ in texts] # Return empty embeddings on failure
48
 
49
  def embed_query(self, text: str) -> List[float]:
50
  return self.embed_documents([text])[0]
51
 
52
- # --- Data Processing ---
53
  def load_and_chunk_data(file_path: str) -> List[Document]:
54
  """Enhanced chunking with metadata preservation"""
 
 
 
 
 
 
 
 
55
  with open(file_path, 'r', encoding='utf-8') as f:
56
  data = json.load(f)
57
 
58
  documents = []
59
- for item in data:
60
  base_content = f"""Source: {item['Source']}
61
  Application: {item['Application']}
62
  Functions: {', '.join(filter(None, [item.get('Function1'), item.get('Function2')]))}
@@ -77,18 +102,50 @@ Biological Mechanisms: {', '.join(item['biological_mechanisms'])}"""
77
  "chunk_id": f"{item['Source']}-{len(documents)+1}"
78
  }
79
  ))
 
 
 
80
  return documents
81
 
82
- # --- Hybrid Retrieval System ---
83
  class EnhancedRetriever:
84
- """BM25 + E5-Mistral embeddings with fusion"""
85
  def __init__(self, documents: List[Document]):
86
- self.bm25 = BM25Retriever.from_documents(documents)
87
- self.bm25.k = 5
88
- self.vector_store = FAISS.from_documents(documents, MistralEmbeddings())
89
  self.vector_retriever = self.vector_store.as_retriever(search_kwargs={"k": 3})
90
 
91
- @lru_cache(maxsize=200)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def retrieve(self, query: str) -> str:
93
  try:
94
  processed_query = self._preprocess_query(query)
@@ -107,6 +164,7 @@ class EnhancedRetriever:
107
  def _preprocess_query(self, query: str) -> str:
108
  return query.lower().strip()
109
 
 
110
  def _hyde_expansion(self, query: str) -> str:
111
  try:
112
  response = client.chat.completions.create(
@@ -167,8 +225,9 @@ def get_ai_response(query: str, context: str) -> str:
167
  {"role": "user", "content": f"Question: {query}\nProvide a detailed technical answer:"}
168
  ],
169
  temperature=0.4,
170
- max_tokens=600
171
  )
 
172
  return _postprocess_response(response.choices[0].message.content)
173
  except Exception as e:
174
  logger.error(f"Generation Error: {str(e)}")
@@ -179,8 +238,8 @@ def _postprocess_response(response: str) -> str:
179
  response = re.sub(r"\*\*([\w-]+)\*\*", r"**\1**", response)
180
  return response
181
 
182
- # --- Pipeline Integration ---
183
- documents = load_and_chunk_data("mini_data_enhanced.json")
184
  retriever = EnhancedRetriever(documents)
185
 
186
  def generate_response(question: str) -> str:
 
2
  import json
3
  import logging
4
  import re
5
+ import os
6
+ import pickle
7
+ from typing import List, Tuple, Optional
8
  import gradio as gr
9
  from openai import OpenAI
10
  from functools import lru_cache
 
14
  from langchain_core.embeddings import Embeddings
15
  from langchain_core.documents import Document
16
  from collections import defaultdict
17
+ import hashlib
18
+ from tqdm import tqdm
 
19
 
20
+ # --- Configuration ---
21
+ FAISS_INDEX_PATH = "faiss_index"
22
+ BM25_INDEX_PATH = "bm25_index.pkl"
23
+ CACHE_VERSION = "v1" # Increment when data format changes
24
  embedding_model = "e5-mistral-7b-instruct"
25
  generation_model = "meta-llama-3-70b-instruct"
26
+ data_file_name = "AskNatureNet_data_enhanced.json"
27
  API_CONFIG = {
28
  "api_key": os.getenv("API_KEY"),
29
  "base_url": "https://chat-ai.academiccloud.de/v1"
30
  }
31
  CHUNK_SIZE = 800
32
  OVERLAP = 200
33
+ EMBEDDING_BATCH_SIZE = 32 # Batch size for embedding API calls
34
 
35
  # Initialize clients
36
  client = OpenAI(**API_CONFIG)
37
  logging.basicConfig(level=logging.INFO)
38
  logger = logging.getLogger(__name__)
39
 
40
# --- Helper Functions ---
def get_data_hash(file_path: str) -> str:
    """Return the hex MD5 digest of the file at *file_path*.

    Used purely as a cache-key component so indexes are rebuilt when the
    underlying data file changes (not for security). Reads the file in
    fixed-size chunks so large data files are not loaded into memory
    all at once.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as f:
        # iter(..., b"") yields chunks until EOF returns the empty sentinel.
        for chunk in iter(lambda: f.read(8192), b""):
            digest.update(chunk)
    return digest.hexdigest()
45
+
46
# --- Custom Embedding Handler with Progress Tracking ---
class MistralEmbeddings(Embeddings):
    """E5-Mistral-7B embedding adapter with error handling and progress tracking."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* via the API in batches of EMBEDDING_BATCH_SIZE.

        Error handling is per batch: a failing batch contributes empty
        vectors for its own texts only, so embeddings already computed
        for earlier batches are preserved instead of being discarded.
        The returned list always has exactly one entry per input text.
        """
        embeddings: List[List[float]] = []
        for i in tqdm(range(0, len(texts), EMBEDDING_BATCH_SIZE), desc="Embedding Progress"):
            batch = texts[i:i + EMBEDDING_BATCH_SIZE]
            try:
                response = client.embeddings.create(
                    input=batch,
                    model=embedding_model,
                    encoding_format="float"
                )
                embeddings.extend([e.embedding for e in response.data])
            except Exception as e:
                logger.error(f"Embedding Error: {str(e)}")
                # Fall back to empty vectors for this batch only.
                embeddings.extend([[] for _ in batch])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string (first row of the batched call).

        NOTE(review): on API failure this returns an empty list, which
        downstream similarity search cannot use — confirm callers handle it.
        """
        return self.embed_documents([text])[0]
68
 
69
+ # --- Data Processing with Cache Validation ---
70
  def load_and_chunk_data(file_path: str) -> List[Document]:
71
  """Enhanced chunking with metadata preservation"""
72
+ current_hash = get_data_hash(file_path)
73
+ cache_file = f"documents_{CACHE_VERSION}_{current_hash}.pkl"
74
+
75
+ if os.path.exists(cache_file):
76
+ logger.info("Loading cached documents")
77
+ with open(cache_file, "rb") as f:
78
+ return pickle.load(f)
79
+
80
  with open(file_path, 'r', encoding='utf-8') as f:
81
  data = json.load(f)
82
 
83
  documents = []
84
+ for item in tqdm(data, desc="Chunking Progress"):
85
  base_content = f"""Source: {item['Source']}
86
  Application: {item['Application']}
87
  Functions: {', '.join(filter(None, [item.get('Function1'), item.get('Function2')]))}
 
102
  "chunk_id": f"{item['Source']}-{len(documents)+1}"
103
  }
104
  ))
105
+
106
+ with open(cache_file, "wb") as f:
107
+ pickle.dump(documents, f)
108
  return documents
109
 
110
# --- Optimized Retrieval System ---
class EnhancedRetriever:
    """Hybrid retriever with persistent caching"""

    def __init__(self, documents: List[Document]):
        # Keep the raw documents so the index builders below can reuse them.
        self.documents = documents
        # Build (or load from their on-disk caches) both indexes up front:
        # sparse lexical (BM25) and dense vector (FAISS).
        self.bm25 = self._init_bm25()
        self.vector_store = self._init_faiss()
        # Dense retriever is configured to return the top-3 nearest chunks.
        self.vector_retriever = self.vector_store.as_retriever(search_kwargs={"k": 3})
118
 
119
+ def _init_bm25(self) -> BM25Retriever:
120
+ cache_key = f"{BM25_INDEX_PATH}_{get_data_hash(data_file_name)}"
121
+ if os.path.exists(cache_key):
122
+ logger.info("Loading cached BM25 index")
123
+ with open(cache_key, "rb") as f:
124
+ return pickle.load(f)
125
+
126
+ logger.info("Building new BM25 index")
127
+ retriever = BM25Retriever.from_documents(self.documents)
128
+ retriever.k = 5
129
+ with open(cache_key, "wb") as f:
130
+ pickle.dump(retriever, f)
131
+ return retriever
132
+
133
+ def _init_faiss(self) -> FAISS:
134
+ cache_key = f"{FAISS_INDEX_PATH}_{get_data_hash(data_file_name)}"
135
+ if os.path.exists(cache_key):
136
+ logger.info("Loading cached FAISS index")
137
+ return FAISS.load_local(
138
+ cache_key,
139
+ MistralEmbeddings(),
140
+ allow_dangerous_deserialization=True
141
+ )
142
+
143
+ logger.info("Building new FAISS index")
144
+ vector_store = FAISS.from_documents(self.documents, MistralEmbeddings())
145
+ vector_store.save_local(cache_key)
146
+ return vector_store
147
+
148
+ @lru_cache(maxsize=500)
149
  def retrieve(self, query: str) -> str:
150
  try:
151
  processed_query = self._preprocess_query(query)
 
164
  def _preprocess_query(self, query: str) -> str:
165
  return query.lower().strip()
166
 
167
+ @lru_cache(maxsize=500)
168
  def _hyde_expansion(self, query: str) -> str:
169
  try:
170
  response = client.chat.completions.create(
 
225
  {"role": "user", "content": f"Question: {query}\nProvide a detailed technical answer:"}
226
  ],
227
  temperature=0.4,
228
+ max_tokens=2000 # Increased max_tokens
229
  )
230
+ logger.info(f"Raw Response: {response.choices[0].message.content}") # Log raw response
231
  return _postprocess_response(response.choices[0].message.content)
232
  except Exception as e:
233
  logger.error(f"Generation Error: {str(e)}")
 
238
  response = re.sub(r"\*\*([\w-]+)\*\*", r"**\1**", response)
239
  return response
240
 
241
# --- Optimized Pipeline ---
# Built once at import time: chunk (or load cached) documents, then
# construct the hybrid retriever over them.
documents = load_and_chunk_data(data_file_name)
retriever = EnhancedRetriever(documents)
244
 
245
  def generate_response(question: str) -> str: