Update app.py
app.py CHANGED
@@ -1,41 +1,7 @@
-import os
-import time
-import gradio as gr
-from dotenv import load_dotenv
-from pathlib import Path
-import re
-import json
-
-# Import Document from your LangChain module.
-# (Adjust the import if your version of LangChain uses a different path.)
-from langchain_core.documents import Document
-
-# Import additional libraries from LangChain
-from langchain_chroma import Chroma
-from langchain_openai import OpenAIEmbeddings
-from langchain_community.retrievers import BM25Retriever
-from langchain.retrievers import EnsembleRetriever
-from langchain_core.runnables import RunnablePassthrough
-from langchain_core.output_parsers import StrOutputParser
-from langchain_openai import ChatOpenAI
-from langchain.chains import create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_core.prompts import ChatPromptTemplate
-
-# Load environment variables for Hugging Face and OpenAI
-load_dotenv()
-os.environ['LANGCHAIN_API_KEY'] = os.getenv('LANGCHAIN_API_KEY')
-os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY')
-
-
-# -------------------------------
-# Utility Functions
-# -------------------------------
-
 import re
 import json
 from pathlib import Path
-#
+# Import your Document class from your LangChain module.
 from langchain_core.documents import Document
 
 def extract_metadata(text: str) -> tuple[dict, str]:
@@ -50,7 +16,6 @@ def extract_metadata(text: str) -> tuple[dict, str]:
     )
     if title_match:
         metadata["title"] = title_match.group(1).strip()
-        # Remove Title from cleaned_text
         cleaned_text = re.sub(
             r"Title:\s*.*?(?=Website:|Twitter:|Instagram:|FaceBook:|Newsletter:)",
             "",
@@ -68,7 +33,6 @@ def extract_metadata(text: str) -> tuple[dict, str]:
         ranking_value = ranking_match.group(1).strip()
         if ranking_value.lower() == "winner":
             metadata["ranking"] = ranking_value
-        # Remove Ranking from cleaned_text
         cleaned_text = re.sub(
             r"Ranking:\s*.*?(?=Impact Metrics:|$)",
             "",
@@ -80,7 +44,6 @@ def extract_metadata(text: str) -> tuple[dict, str]:
     year_match = re.search(r"Year:\s*(\d{4})", cleaned_text, re.IGNORECASE)
     if year_match:
         metadata["year"] = year_match.group(1).strip()
-        # Remove Year from cleaned_text
         cleaned_text = re.sub(r"Year:\s*\d{4}", "", cleaned_text, flags=re.IGNORECASE)
 
     # Extract and remove Organization
@@ -91,7 +54,6 @@ def extract_metadata(text: str) -> tuple[dict, str]:
     )
     if org_match:
         metadata["organization"] = org_match.group(1).strip()
-        # Remove Organization from cleaned_text
         cleaned_text = re.sub(
             r"Organization:\s*.*?(?=Goal:|Ranking:|Impact Metrics:)",
             "",
@@ -103,7 +65,6 @@ def extract_metadata(text: str) -> tuple[dict, str]:
     urls = re.findall(r"(Website|Volunteer|Newsletter):\s*((?:https?://)?\S+)", cleaned_text)
     for key, url in urls:
         metadata[key.lower()] = url.strip()
-        # Remove URL from cleaned_text
         cleaned_text = re.sub(
             rf"{key}:\s*{re.escape(url)}",
             "",
@@ -111,14 +72,13 @@ def extract_metadata(text: str) -> tuple[dict, str]:
             flags=re.IGNORECASE
         )
 
-    # Extract and remove social handles
+    # Extract and remove social handles (Twitter, Instagram, FaceBook)
     social = re.findall(r"(Twitter|Instagram|FaceBook):\s*(\S+)", cleaned_text)
     for platform, handle in social:
         if handle.startswith("http"):
             metadata[platform.lower()] = handle.strip()
         else:
             metadata[f"{platform.lower()}_handle"] = f"https://{platform.lower()}.com/{handle.strip()}"
-        # Remove social handle from cleaned_text
         cleaned_text = re.sub(
             rf"{platform}:\s*{re.escape(handle)}",
             "",
@@ -126,13 +86,54 @@ def extract_metadata(text: str) -> tuple[dict, str]:
             flags=re.IGNORECASE
         )
 
+    # Extract and remove Working Areas in LA
+    working_match = re.search(
+        r"Working Areas in LA:\s*(.*?)\s+(?=Summary:|Ranking:|Impact Metrics:|$)",
+        cleaned_text,
+        re.IGNORECASE | re.DOTALL
+    )
+    if working_match:
+        metadata["working_areas"] = working_match.group(1).strip()
+        cleaned_text = re.sub(
+            r"Working Areas in LA:\s*.*?(?=Summary:|Ranking:|Impact Metrics:|$)",
+            "",
+            cleaned_text,
+            flags=re.IGNORECASE | re.DOTALL
+        )
+
+    # Extract and remove Zipcode (assuming 5-digit US zipcodes)
+    zipcode_match = re.search(r"Zipcode:\s*(\d{5})", cleaned_text, re.IGNORECASE)
+    if zipcode_match:
+        metadata["zipcode"] = zipcode_match.group(1).strip()
+        cleaned_text = re.sub(r"Zipcode:\s*\d{5}", "", cleaned_text, flags=re.IGNORECASE)
+
     # Clean up extra whitespace
     cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()
 
-    return metadata, cleaned_text
+    # Create a metadata summary to append to the cleaned text.
+    meta_summary = ""
+    if "year" in metadata:
+        meta_summary += f"Year: {metadata['year']}. "
+    if "ranking" in metadata:
+        meta_summary += f"Ranking: {metadata['ranking']}. "
+    if "organization" in metadata:
+        meta_summary += f"Organization: {metadata['organization']}. "
+    if "working_areas" in metadata:
+        meta_summary += f"Working Areas in LA: {metadata['working_areas']}. "
+    if "zipcode" in metadata:
+        meta_summary += f"Zipcode: {metadata['zipcode']}. "
+
+    combined_text = meta_summary + "\n" + cleaned_text if meta_summary else cleaned_text
+
+    return metadata, combined_text
 
 
 def load_and_process_data(file_path: str):
+    """
+    Loads JSON data from a file, extracts organization text and metadata (including working areas and zipcode),
+    cleans the text by removing redundant metadata, and returns a list of Documents.
+    Documents with a "winner" ranking are inserted at the beginning of the list.
+    """
     try:
         data = json.loads(Path(file_path).read_text(encoding='utf-8'))
         docs = []
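Note: the new metadata summary is prepended to the cleaned text, so every chunk embeds the year, ranking, organization, working areas, and zipcode alongside the prose. A quick sketch of the behavior; the sample string and organization are invented to match the field layout these regexes target, not taken from data.json:

# Illustrative sketch only: the sample entry below is made up.
sample = (
    "Organization: Heal the Bay Goal: clean water access "
    "Working Areas in LA: Santa Monica, Venice Summary: coastal cleanups. "
    "Zipcode: 90401 Year: 2023 Ranking: Winner Impact Metrics: 12 beaches restored"
)
metadata, combined_text = extract_metadata(sample)
print(metadata.get("working_areas"))  # e.g. "Santa Monica, Venice"
print(metadata.get("zipcode"))        # e.g. "90401"
print(combined_text.splitlines()[0])  # e.g. "Year: 2023. Ranking: Winner. Organization: ..." summary line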
@@ -140,15 +141,16 @@ def load_and_process_data(file_path: str):
             org_text = entry.get("OrganizationText", "")
             if not org_text:
                 continue
-            metadata, cleaned_text = extract_metadata(org_text)
+            metadata, combined_text = extract_metadata(org_text)
             if metadata.get("ranking", "").lower() == "winner":
-                docs.insert(0, Document(page_content=cleaned_text, metadata=metadata))
+                docs.insert(0, Document(page_content=combined_text, metadata=metadata))
             else:
-                docs.append(Document(page_content=cleaned_text, metadata=metadata))
+                docs.append(Document(page_content=combined_text, metadata=metadata))
         return docs
     except Exception as e:
         print(f"Error loading JSON: {e}")
         return []
+
 # -------------------------------
 # Data Loading and Preprocessing
 # -------------------------------
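Side note on the winner-first ordering kept by this hunk: repeated docs.insert(0, ...) is quadratic and reverses the winners' relative order. A stable sort over the same Document objects is an equivalent single pass (sketch, not part of this commit):

# Sketch: a stable sort moves "winner"-ranked docs to the front in one pass,
# preserving their original relative order (insert(0, ...) reverses it).
docs.sort(key=lambda d: d.metadata.get("ranking", "").lower() != "winner")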
@@ -157,11 +159,10 @@ file_path = './data.json'  # Ensure this file is available in your environment.
 docs = load_and_process_data(file_path)
 
 # Use a text splitter to create chunks from the documents.
-# (If you find that key fields are getting split, consider implementing a custom splitter.)
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 text_splitter = RecursiveCharacterTextSplitter(
-    chunk_size=
-    chunk_overlap=
+    chunk_size=1800,
+    chunk_overlap=200,
     add_start_index=True
 )
 all_splits = text_splitter.split_documents(docs)
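With add_start_index=True each split records its character offset into the source document, which makes the new 1800/200 chunking easy to sanity-check (illustrative):

# Quick check of the chunking output (illustrative).
for chunk in all_splits[:3]:
    print(
        len(chunk.page_content),                  # should be <= 1800 characters
        chunk.metadata.get("start_index"),        # offset into the source document
        chunk.metadata.get("organization", "?"),  # metadata carried over from the Document
    )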
@@ -170,7 +171,11 @@ all_splits = text_splitter.split_documents(docs)
 # Set Up Retrievers
 # -------------------------------
 
-
+from langchain_chroma import Chroma
+from langchain_openai import OpenAIEmbeddings
+from langchain_community.retrievers import BM25Retriever
+from langchain.retrievers import EnsembleRetriever
+
 persist_directory = "./chroma_db"
 if os.path.exists(persist_directory) and os.listdir(persist_directory):
     vectorstore = Chroma(
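The diff collapses the bodies of the load/create branches. A minimal sketch of what they presumably look like with langchain_chroma; the argument names are the library's, but the branch bodies are an assumption, not the author's exact code:

# Assumed shape of the elided branches (sketch).
embeddings = OpenAIEmbeddings()
if os.path.exists(persist_directory) and os.listdir(persist_directory):
    # Reuse previously persisted embeddings.
    vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
else:
    # Embed all splits once and persist them for later runs.
    vectorstore = Chroma.from_documents(
        documents=all_splits,
        embedding=embeddings,
        persist_directory=persist_directory,
    )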
@@ -186,13 +191,10 @@ else:
     )
     print("Created new vector store and persisted embeddings.")
 
-# Create a BM25 retriever from the document splits.
 bm25_retriever = BM25Retriever.from_documents(all_splits)
-
-# Combine the retrievers using an ensemble approach.
 ensemble_retriever = EnsembleRetriever(
     retrievers=[vectorstore.as_retriever(search_kwargs={"k": 6}), bm25_retriever],
-    weights=[0.
+    weights=[0.7, 0.3]
 )
 retriever = ensemble_retriever
 
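EnsembleRetriever fuses the two rankings with reciprocal rank fusion, so weights=[0.7, 0.3] biases the fused ranking toward the dense Chroma retriever over keyword BM25. Illustrative usage (the query is invented):

# Illustrative query; results interleave dense (0.7) and BM25 (0.3) rankings.
for doc in ensemble_retriever.invoke("youth education programs in South LA")[:3]:
    print(doc.metadata.get("organization"), "-", doc.metadata.get("ranking", "n/a"))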
@@ -200,31 +202,24 @@ retriever = ensemble_retriever
 # Prepare Retrieval and Generation Chain
 # -------------------------------
 
-
-
+from langchain_openai import ChatOpenAI
+from langchain.chains import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import ChatPromptTemplate
 
+system_prompt = (
     "You are the LA2050 Navigator, an AI-powered chatbot designed to help users explore organizations and community initiatives within the Goldhirsh Foundation’s LA2050 Ideas Hub. "
-
     "Your role is to provide concise, personalized recommendations, guide users toward supporting these organizations and initiatives, and answer relevant questions about the Goldhirsh Foundation, LA2050, and its projects. "
-
     "When answering, include the full name of the organization, a brief (1-2 sentence) description, and a link to its website or social media (as provided under the website column; please do not alter or normalize the URL). "
-
     "If a company's personal website is unavailable, navigate to the LA2050 URLs. "
-
     "Prioritize nonprofit organizations awarded by the Goldhirsh Foundation (designated 'Winner' under ranking column) and those with multiple proposal submissions. "
-
     "Use the data files as your primary source of information. If information is unavailable, acknowledge it and guide the user to relevant resources. "
-
     "Maintain a polite, helpful, respectful, and enthusiastic tone at all times. "
-
     "If the user responds with a follow-up confirmation (e.g. 'yes') after a previous answer, please expand on that topic with additional information. "
-    "When answering questions about grant winners, only list organizations whose metadata ranking field is marked as 'Winner'"
-
+    "When answering questions about grant winners, only list organizations whose metadata ranking field is marked as 'Winner'."
     "\n\n{context}"
-
 )
 
-
 prompt = ChatPromptTemplate.from_messages(
     [
         ("system", system_prompt),
@@ -232,7 +227,6 @@ prompt = ChatPromptTemplate.from_messages(
     ]
 )
 
-# Build the chain that will combine documents with the prompt.
 question_answer_chain = create_stuff_documents_chain(ChatOpenAI(model_name="gpt-4o-mini", temperature=0), prompt)
 rag_chain = create_retrieval_chain(retriever, question_answer_chain)
 
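create_retrieval_chain returns a dict carrying the original "input", the retrieved "context" documents, and the generated "answer", which is the key the Gradio handler below reads. A minimal smoke test (the question is invented):

# Smoke test (illustrative question).
response = rag_chain.invoke({"input": "Which grant winners work on food security?"})
print(response["answer"])                         # generated reply
print(len(response["context"]), "source chunks")  # retrieved documents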
@@ -240,6 +234,11 @@ rag_chain = create_retrieval_chain(retriever, question_answer_chain)
 # Gradio Interface and Conversation Handling
 # -------------------------------
 
+import gradio as gr
+import time
+from dotenv import load_dotenv
+load_dotenv()
+
 green_theme = gr.themes.Base(
     primary_hue=gr.themes.Color(
         c50="#00A168",
@@ -272,24 +271,19 @@ green_theme = gr.themes.Base(
 )
 
 def message_and_history(message, history):
-    # Initialize conversation with a welcome message if history is empty.
     history = history or [{"role": "assistant", "content": "<b>LA2050 Navigator:</b><br> Welcome to the LA2050 ideas hub! How can I help you today?"}]
     user_text = message.get("text", "")
     history.append({"role": "user", "content": user_text})
 
     time.sleep(1)
-
-    # If the user did not provide any input, ask for a valid message.
     if not user_text:
         history.append({"role": "assistant", "content": "<b>LA2050 Navigator:</b><br> Please enter a valid message."})
         yield history, history
         return
 
-    # Combine the most recent conversation turns, excluding the assistant's prefix.
     conversation_context = "\n".join(
         [f"{msg['role']}: {msg['content'].replace('<b>LA2050 Navigator:</b><br>', '')}" for msg in history[-3:]]
     )
-
     chain_input = {"input": conversation_context}
 
     try:
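Only the last three turns are flattened into the chain input, so follow-up confirmations like "yes" resolve against a short window rather than the full history. A sketch of what chain_input["input"] looks like for a fresh session:

# Sketch of the flattened context for a two-message history.
history = [
    {"role": "assistant", "content": "<b>LA2050 Navigator:</b><br> Welcome to the LA2050 ideas hub! How can I help you today?"},
    {"role": "user", "content": "tell me about river cleanups"},
]
print("\n".join(
    f"{m['role']}: {m['content'].replace('<b>LA2050 Navigator:</b><br>', '')}" for m in history[-3:]
))
# assistant:  Welcome to the LA2050 ideas hub! How can I help you today?
# user: tell me about river cleanups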
@@ -297,25 +291,20 @@ def message_and_history(message, history):
         answer = response["answer"]
     except Exception as e:
         answer = f"An error occurred: {e}"
-
-    # Remove the prefix if the model includes it.
+
     if answer.startswith("<b>LA2050 Navigator:</b><br>"):
         answer = answer[len("<b>LA2050 Navigator:</b><br>"):]
-
-    # Initialize the assistant's response with the prefix.
+
     assistant_response = {"role": "assistant", "content": "<b>LA2050 Navigator:</b><br> "}
     history.append(assistant_response)
-
-    # Stream the answer character by character.
+
     for character in answer:
         assistant_response["content"] += character
         yield history, history
-
-    # Finalize the answer without re-adding the prefix.
+
     history[-1]["content"] = assistant_response["content"]
     yield history, history
 
-# Set Gradio to light mode via JavaScript
 js_func = """
 function refresh() {
     const url = new URL(window.location);
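Because message_and_history is a generator, Gradio re-renders the chat on every yield, and appending one character per yield is what produces the typewriter effect. Consuming it outside Gradio looks like this (sketch; it triggers a real chain call):

# Each iteration corresponds to one UI refresh of the Gradio Chatbot.
final = None
for chat, _state in message_and_history({"text": "hi"}, []):
    final = chat  # keep the last yielded history
print(final[-1]["content"][:60])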
@@ -358,7 +347,6 @@ with gr.Blocks(theme=green_theme, js=js_func, css=css) as block:
         show_label=False
     )
 
-    # When a message is submitted, the function now sends the recent conversation history along with the new input.
     message.submit(
         message_and_history,
        inputs=[message, state],
@@ -368,3 +356,5 @@ with gr.Blocks(theme=green_theme, js=js_func, css=css) as block:
 )
 
 block.launch(debug=True, share=True)
+
+