Mohamed284 committed on
Commit
453f31f
·
verified ·
1 Parent(s): 37a3135

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -136
app.py CHANGED
@@ -1,142 +1,210 @@
1
- import os
2
  import json
3
- import pandas as pd
4
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
5
- from langchain_community.vectorstores import FAISS
6
- from langchain_core.prompts import PromptTemplate
7
- from langchain_core.output_parsers import StrOutputParser
8
- from operator import itemgetter
9
  import gradio as gr
10
- from langchain_community.embeddings import HuggingFaceEmbeddings
11
-
12
- # Configuration
13
- USE_HF = True
14
- MODEL_NAME = "stanford-crfm/BioMedLM"
15
- BATCH_SIZE = 8 # Adjusted batch size for memory optimization
16
-
17
- # Load data
18
- with open('AskNatureNet_data.json', 'r', encoding='utf-8') as f:
19
- data = json.load(f)
20
- df = pd.DataFrame(data)
21
- documents = [
22
- f"Source: {item['Source']}\nApplication: {item['Application']}\nFunction1: {item['Function1']}\nStrategy: {item['Strategy']}"
23
- for item in data
24
- ]
25
-
26
- if USE_HF:
27
- print("Using Hugging Face model...")
28
-
29
- huggingface_token = os.environ.get("AskNature_RAG")
30
-
31
- # Quantization configuration for 4-bit precision
32
- bnb_config = BitsAndBytesConfig(
33
- load_in_4bit=True,
34
- bnb_4bit_use_double_quant=True,
35
- bnb_4bit_quant_type="nf4"
36
- )
37
-
38
- # Load tokenizer and model with offloading and quantization
39
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=huggingface_token)
40
- model = AutoModelForCausalLM.from_pretrained(
41
- MODEL_NAME,
42
- device_map="cpu",
43
- offload_folder="offload", # Specify the offload folder
44
- quantization_config=bnb_config,
45
- use_auth_token=huggingface_token
46
- )
47
- embeddings = HuggingFaceEmbeddings(model_name=MODEL_NAME)
48
- lang_model = model
49
- else:
50
- print("Using local model...")
51
- # Local model loading logic here
52
-
53
- # Generate embeddings in batches
54
- batched_embeddings = [
55
- embeddings.embed_documents(documents[i:i + BATCH_SIZE])
56
- for i in range(0, len(documents), BATCH_SIZE)
57
- ]
58
- batched_embeddings = [embed for batch in batched_embeddings for embed in batch]
59
-
60
- # FAISS index handling
61
- index_path = "faiss_index"
62
- if os.path.exists(index_path):
63
- vectorstore = FAISS.load_local(index_path, embeddings)
64
- else:
65
- vectorstore = FAISS.from_texts(documents, embeddings)
66
- vectorstore.save_local(index_path)
67
-
68
- retriever = vectorstore.as_retriever()
69
-
70
- # Prompt template
71
- template = """
72
- Answer the question based on the context below. If you can't
73
- answer the question, reply "I don't know".
74
- Context: {context}
75
- Question: {question}
76
- """
77
- prompt = PromptTemplate.from_template(template)
78
-
79
- # Chain definition
80
- chain = {
81
- "context": itemgetter("question") | retriever,
82
- "question": itemgetter("question"),
83
- } | prompt | lang_model | StrOutputParser()
84
-
85
- # Question-answering function
86
- def rag_qa(question):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  try:
88
- return chain.invoke({'question': question})
 
89
  except Exception as e:
90
- return f"Error: {str(e)}"
91
-
92
- # Gradio chatbot interface
93
- def respond(
94
- message,
95
- history: list[tuple[str, str]],
96
- system_message,
97
- max_tokens,
98
- temperature,
99
- top_p,
100
- ):
101
- messages = [{"role": "system", "content": system_message}]
102
-
103
- for val in history:
104
- if val[0]:
105
- messages.append({"role": "user", "content": val[0]})
106
- if val[1]:
107
- messages.append({"role": "assistant", "content": val[1]})
108
-
109
- messages.append({"role": "user", "content": message})
110
-
111
- response = ""
112
-
113
- for message in client.chat_completion(
114
- messages,
115
- max_tokens=max_tokens,
116
- stream=True,
117
- temperature=temperature,
118
- top_p=top_p,
119
- ):
120
- token = message.choices[0].delta.content
121
- response += token
122
- yield response
123
-
124
- # Gradio interface setup
125
- demo = gr.ChatInterface(
126
- respond,
127
- additional_inputs=[
128
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
129
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
130
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
131
- gr.Slider(
132
- minimum=0.1,
133
- maximum=1.0,
134
- value=0.95,
135
- step=0.05,
136
- label="Top-p (nucleus sampling)",
137
- ),
138
- ],
139
- )
140
 
141
  if __name__ == "__main__":
142
- demo.launch()
 
# Optimized RAG System with E5-Mistral Embeddings and Llama3-70B Generation
import json
import logging
import os
import re
from collections import defaultdict
from functools import lru_cache
from typing import List, Tuple

import gradio as gr
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_exponential
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from langchain_core.documents import Document
# --- Configuration ---
embedding_model = "e5-mistral-7b-instruct"
generation_model = "meta-llama-3-70b-instruct"

# SECURITY FIX: the API key was previously hard-coded in source. Read it from
# the environment instead so the secret never lives in version control.
# Set CHAT_AI_API_KEY in the deployment environment (e.g. HF Space secrets).
API_CONFIG = {
    "api_key": os.environ.get("CHAT_AI_API_KEY", ""),
    "base_url": "https://chat-ai.academiccloud.de/v1"
}
CHUNK_SIZE = 800   # characters per strategy chunk
OVERLAP = 200      # characters shared between consecutive chunks

# Initialize clients
client = OpenAI(**API_CONFIG)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
31
+ # --- Custom Embedding Handler ---
32
+ class MistralEmbeddings(Embeddings):
33
+ """E5-Mistral-7B embedding adapter with error handling"""
34
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
35
+ try:
36
+ response = client.embeddings.create(
37
+ input=texts,
38
+ model=embedding_model,
39
+ encoding_format="float"
40
+ )
41
+ return [e.embedding for e in response.data]
42
+ except Exception as e:
43
+ logger.error(f"Embedding Error: {str(e)}")
44
+ return [[] for _ in texts] # Return empty embeddings on failure
45
+
46
+ def embed_query(self, text: str) -> List[float]:
47
+ return self.embed_documents([text])[0]
48
+
49
+ # --- Data Processing ---
50
+ def load_and_chunk_data(file_path: str) -> List[Document]:
51
+ """Enhanced chunking with metadata preservation"""
52
+ with open(file_path, 'r', encoding='utf-8') as f:
53
+ data = json.load(f)
54
+
55
+ documents = []
56
+ for item in data:
57
+ base_content = f"""Source: {item['Source']}
58
+ Application: {item['Application']}
59
+ Functions: {', '.join(filter(None, [item.get('Function1'), item.get('Function2')]))}
60
+ Technical Concepts: {', '.join(item['technical_concepts'])}
61
+ Biological Mechanisms: {', '.join(item['biological_mechanisms'])}"""
62
+
63
+ strategy = item['Strategy']
64
+ for i in range(0, len(strategy), CHUNK_SIZE - OVERLAP):
65
+ chunk = strategy[i:i + CHUNK_SIZE]
66
+ documents.append(Document(
67
+ page_content=f"{base_content}\nStrategy Excerpt:\n{chunk}",
68
+ metadata={
69
+ "source": item["Source"],
70
+ "application": item["Application"],
71
+ "technical_concepts": item["technical_concepts"],
72
+ "sustainability_impacts": item["sustainability_impacts"],
73
+ "hyperlink": item["Hyperlink"],
74
+ "chunk_id": f"{item['Source']}-{len(documents)+1}"
75
+ }
76
+ ))
77
+ return documents
78
+
79
+ # --- Hybrid Retrieval System ---
80
+ class EnhancedRetriever:
81
+ """BM25 + E5-Mistral embeddings with fusion"""
82
+ def __init__(self, documents: List[Document]):
83
+ self.bm25 = BM25Retriever.from_documents(documents)
84
+ self.bm25.k = 5
85
+ self.vector_store = FAISS.from_documents(documents, MistralEmbeddings())
86
+ self.vector_retriever = self.vector_store.as_retriever(search_kwargs={"k": 3})
87
+
88
+ @lru_cache(maxsize=200)
89
+ def retrieve(self, query: str) -> str:
90
+ try:
91
+ processed_query = self._preprocess_query(query)
92
+ expanded_query = self._hyde_expansion(processed_query)
93
+
94
+ bm25_results = self.bm25.invoke(processed_query)
95
+ vector_results = self.vector_retriever.invoke(processed_query)
96
+ expanded_results = self.bm25.invoke(expanded_query)
97
+
98
+ fused_results = self._fuse_results([bm25_results, vector_results, expanded_results])
99
+ return self._format_context(fused_results[:5])
100
+ except Exception as e:
101
+ logger.error(f"Retrieval Error: {str(e)}")
102
+ return ""
103
+
104
+ def _preprocess_query(self, query: str) -> str:
105
+ return query.lower().strip()
106
+
107
+ def _hyde_expansion(self, query: str) -> str:
108
+ try:
109
+ response = client.chat.completions.create(
110
+ model=generation_model,
111
+ messages=[{
112
+ "role": "user",
113
+ "content": f"Generate a technical draft about biomimicry for: {query}\nInclude domain-specific terms."
114
+ }],
115
+ temperature=0.5,
116
+ max_tokens=200
117
+ )
118
+ return response.choices[0].message.content
119
+ except Exception as e:
120
+ logger.error(f"HyDE Error: {str(e)}")
121
+ return query
122
+
123
+ def _fuse_results(self, result_sets: List[List[Document]]) -> List[Document]:
124
+ fused_scores = defaultdict(float)
125
+ for docs in result_sets:
126
+ for rank, doc in enumerate(docs, 1):
127
+ fused_scores[doc.metadata["chunk_id"]] += 1 / (rank + 60)
128
+
129
+ seen = set()
130
+ return [
131
+ doc for doc in sorted(
132
+ (doc for docs in result_sets for doc in docs),
133
+ key=lambda x: fused_scores[x.metadata["chunk_id"]],
134
+ reverse=True
135
+ ) if not (doc.metadata["chunk_id"] in seen or seen.add(doc.metadata["chunk_id"]))
136
+ ]
137
+
138
+ def _format_context(self, docs: List[Document]) -> str:
139
+ context = []
140
+ for doc in docs:
141
+ context_str = f"""**Source**: {doc.metadata['source']}
142
+ **Application**: {doc.metadata['application']}
143
+ **Concepts**: {', '.join(doc.metadata['technical_concepts'])}
144
+ **Excerpt**: {doc.page_content.split('Strategy Excerpt:')[-1].strip()}
145
+ **Reference**: {doc.metadata['hyperlink']}"""
146
+ context.append(context_str)
147
+ return "\n\n---\n\n".join(context)
148
+
149
+ # --- Generation System ---
150
+ SYSTEM_PROMPT = """**Biomimicry Expert Guidelines**
151
+ 1. Base answers strictly on context
152
+ 2. Cite sources as [Source]
153
+ 3. **Bold** technical terms
154
+ 4. Include reference links
155
+
156
+ Context: {context}"""
157
+
158
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=20))
159
+ def get_ai_response(query: str, context: str) -> str:
160
+ try:
161
+ response = client.chat.completions.create(
162
+ model=generation_model,
163
+ messages=[
164
+ {"role": "system", "content": SYSTEM_PROMPT.format(context=context)},
165
+ {"role": "user", "content": f"Question: {query}\nProvide a detailed technical answer:"}
166
+ ],
167
+ temperature=0.4,
168
+ max_tokens=600
169
+ )
170
+ return _postprocess_response(response.choices[0].message.content)
171
+ except Exception as e:
172
+ logger.error(f"Generation Error: {str(e)}")
173
+ return "I'm unable to generate a response right now. Please try again later."
174
+
175
+ def _postprocess_response(response: str) -> str:
176
+ response = re.sub(r"\[(.*?)\]", r"[\1](#)", response)
177
+ response = re.sub(r"\*\*([\w-]+)\*\*", r"**\1**", response)
178
+ return response
179
+
180
+ # --- Pipeline Integration ---
181
+ documents = load_and_chunk_data("mini_data_enhanced.json")
182
+ retriever = EnhancedRetriever(documents)
183
+
184
+ def generate_response(question: str) -> str:
185
  try:
186
+ context = retriever.retrieve(question)
187
+ return get_ai_response(question, context) if context else "No relevant information found."
188
  except Exception as e:
189
+ logger.error(f"Pipeline Error: {str(e)}")
190
+ return "An error occurred processing your request."
191
+
192
+ # --- Gradio Interface ---
193
+ def chat_interface(question: str, history: List[Tuple[str, str]]):
194
+ response = generate_response(question)
195
+ return "", history + [(question, response)]
196
+
197
+ with gr.Blocks(title="BioRAG Expert", theme=gr.themes.Soft()) as demo:
198
+ gr.Markdown("# 🌿 BioRAG: Biomimicry Technical Assistant")
199
+ with gr.Row():
200
+ chatbot = gr.Chatbot(label="Dialogue History", height=500)
201
+ with gr.Row():
202
+ question = gr.Textbox(placeholder="Ask about nature-inspired innovations...",
203
+ label="Technical Inquiry", scale=4)
204
+ clear_btn = gr.Button("Clear History", variant="secondary")
205
+
206
+ question.submit(chat_interface, [question, chatbot], [question, chatbot])
207
+ clear_btn.click(lambda: [], None, chatbot)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

if __name__ == "__main__":
    # show_error=True surfaces server-side exceptions in the UI for debugging.
    demo.launch(show_error=True)