Spaces:
Sleeping
Sleeping
Update openai_rag_chatbot.py
Browse files- openai_rag_chatbot.py +8 -10
openai_rag_chatbot.py
CHANGED
|
@@ -6,7 +6,6 @@ import re
|
|
| 6 |
import warnings
|
| 7 |
import logging
|
| 8 |
import streamlit as st
|
| 9 |
-
from dotenv import load_dotenv
|
| 10 |
from langdetect import detect
|
| 11 |
|
| 12 |
from langchain_community.chat_models import ChatOpenAI
|
|
@@ -17,8 +16,9 @@ from langchain.indexes import VectorstoreIndexCreator
|
|
| 17 |
from langchain.chains import RetrievalQA
|
| 18 |
from langchain_core.prompts import ChatPromptTemplate
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
|
|
|
|
| 22 |
warnings.filterwarnings("ignore")
|
| 23 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
| 24 |
|
|
@@ -42,7 +42,7 @@ def detect_language(text):
|
|
| 42 |
|
| 43 |
@st.cache_resource
|
| 44 |
def get_vectorstore():
|
| 45 |
-
pdf_path = "bad_breisig_docs.pdf" #
|
| 46 |
loaders = [PyPDFLoader(pdf_path)]
|
| 47 |
return VectorstoreIndexCreator(
|
| 48 |
embedding=HuggingFaceEmbeddings(model_name='all-MiniLM-L12-v2'),
|
|
@@ -52,8 +52,8 @@ def get_vectorstore():
|
|
| 52 |
).from_loaders(loaders).vectorstore
|
| 53 |
|
| 54 |
def extract_political_groups(text):
|
| 55 |
-
pattern = re.compile(r'
|
| 56 |
-
return '\n'.join([m.group(0).strip() for m in pattern.finditer(text)])
|
| 57 |
|
| 58 |
prompt = st.chat_input('Pass your prompt here')
|
| 59 |
|
|
@@ -65,7 +65,7 @@ if prompt:
|
|
| 65 |
openai_chat = ChatOpenAI(
|
| 66 |
model_name="gpt-3.5-turbo",
|
| 67 |
temperature=0,
|
| 68 |
-
openai_api_key=
|
| 69 |
)
|
| 70 |
|
| 71 |
lang = detect_language(prompt)
|
|
@@ -101,7 +101,6 @@ Context:
|
|
| 101 |
result = chain({"query": prompt})
|
| 102 |
response = result["result"].strip()
|
| 103 |
|
| 104 |
-
# Special case: Political group list
|
| 105 |
if any(x in prompt.lower() for x in ["partei", "gruppierung", "gruppen", "parties", "political"]):
|
| 106 |
fallback_docs = result.get("source_documents", [])
|
| 107 |
combined_text = "\n".join(doc.page_content for doc in fallback_docs)
|
|
@@ -109,7 +108,6 @@ Context:
|
|
| 109 |
if filtered:
|
| 110 |
response = f"Die politischen Gruppierungen in Bad Breisig sind:\n\n{filtered}" if lang == "de" else f"The political groups in Bad Breisig are:\n\n{filtered}"
|
| 111 |
|
| 112 |
-
# Fallback if no relevant info
|
| 113 |
if not response or "not found" in response.lower() or "nicht im kontext" in response.lower():
|
| 114 |
fallback_docs = vectorstore.similarity_search_with_score(prompt, k=3)
|
| 115 |
keyword_hits = list({doc.page_content.strip()[:300] for doc, _ in fallback_docs})
|
|
@@ -124,4 +122,4 @@ Context:
|
|
| 124 |
st.session_state.messages.append({'role': 'assistant', 'content': response})
|
| 125 |
|
| 126 |
except Exception as e:
|
| 127 |
-
st.error(f"❌ Error: {str(e)}")
|
|
|
|
| 6 |
import warnings
|
| 7 |
import logging
|
| 8 |
import streamlit as st
|
|
|
|
| 9 |
from langdetect import detect
|
| 10 |
|
| 11 |
from langchain_community.chat_models import ChatOpenAI
|
|
|
|
| 16 |
from langchain.chains import RetrievalQA
|
| 17 |
from langchain_core.prompts import ChatPromptTemplate
|
| 18 |
|
| 19 |
+
# 🔐 Embed your API key directly for Streamlit Cloud deployment
|
| 20 |
+
OPENAI_API_KEY = "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" # Replace with your actual key
|
| 21 |
+
|
| 22 |
warnings.filterwarnings("ignore")
|
| 23 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
| 24 |
|
|
|
|
| 42 |
|
| 43 |
@st.cache_resource
|
| 44 |
def get_vectorstore():
|
| 45 |
+
pdf_path = "bad_breisig_docs.pdf" # Ensure this path is correct on Streamlit Cloud
|
| 46 |
loaders = [PyPDFLoader(pdf_path)]
|
| 47 |
return VectorstoreIndexCreator(
|
| 48 |
embedding=HuggingFaceEmbeddings(model_name='all-MiniLM-L12-v2'),
|
|
|
|
| 52 |
).from_loaders(loaders).vectorstore
|
| 53 |
|
| 54 |
def extract_political_groups(text):
|
| 55 |
+
pattern = re.compile(r'(AsF|CDU|SPD|FDP|Junge Union|Senioren-Union|Freie W[aä]hlergruppe)[^\n]*', re.IGNORECASE)
|
| 56 |
+
return '\n'.join(sorted(set([m.group(0).strip() for m in pattern.finditer(text)])))
|
| 57 |
|
| 58 |
prompt = st.chat_input('Pass your prompt here')
|
| 59 |
|
|
|
|
| 65 |
openai_chat = ChatOpenAI(
|
| 66 |
model_name="gpt-3.5-turbo",
|
| 67 |
temperature=0,
|
| 68 |
+
openai_api_key=OPENAI_API_KEY
|
| 69 |
)
|
| 70 |
|
| 71 |
lang = detect_language(prompt)
|
|
|
|
| 101 |
result = chain({"query": prompt})
|
| 102 |
response = result["result"].strip()
|
| 103 |
|
|
|
|
| 104 |
if any(x in prompt.lower() for x in ["partei", "gruppierung", "gruppen", "parties", "political"]):
|
| 105 |
fallback_docs = result.get("source_documents", [])
|
| 106 |
combined_text = "\n".join(doc.page_content for doc in fallback_docs)
|
|
|
|
| 108 |
if filtered:
|
| 109 |
response = f"Die politischen Gruppierungen in Bad Breisig sind:\n\n{filtered}" if lang == "de" else f"The political groups in Bad Breisig are:\n\n{filtered}"
|
| 110 |
|
|
|
|
| 111 |
if not response or "not found" in response.lower() or "nicht im kontext" in response.lower():
|
| 112 |
fallback_docs = vectorstore.similarity_search_with_score(prompt, k=3)
|
| 113 |
keyword_hits = list({doc.page_content.strip()[:300] for doc, _ in fallback_docs})
|
|
|
|
| 122 |
st.session_state.messages.append({'role': 'assistant', 'content': response})
|
| 123 |
|
| 124 |
except Exception as e:
|
| 125 |
+
st.error(f"❌ Error: {str(e)}")
|