sourize
committed on
Commit
·
873decf
1
Parent(s):
13f2322
Updated main.py
Browse files
app.py
CHANGED
|
@@ -14,6 +14,9 @@ def load_resources():
|
|
| 14 |
model='google/flan-t5-base',
|
| 15 |
tokenizer='google/flan-t5-base',
|
| 16 |
device=-1,
|
|
|
|
|
|
|
|
|
|
| 17 |
)
|
| 18 |
return embedder, chat_gen
|
| 19 |
|
|
@@ -66,7 +69,7 @@ def main():
|
|
| 66 |
|
| 67 |
# Initialize state
|
| 68 |
if 'history' not in st.session_state:
|
| 69 |
-
st.session_state.history = []
|
| 70 |
if 'chunks' not in st.session_state:
|
| 71 |
st.session_state.chunks = []
|
| 72 |
if 'index' not in st.session_state:
|
|
@@ -80,8 +83,8 @@ def main():
|
|
| 80 |
st.session_state.embedder, st.session_state.chat_gen = load_resources()
|
| 81 |
st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
|
| 82 |
st.session_state.uploaded_name = uploaded.name
|
| 83 |
-
st.session_state.history = []
|
| 84 |
-
# Load models if
|
| 85 |
if 'embedder' not in st.session_state or 'chat_gen' not in st.session_state:
|
| 86 |
st.session_state.embedder, st.session_state.chat_gen = load_resources()
|
| 87 |
|
|
@@ -93,23 +96,25 @@ def main():
|
|
| 93 |
# Chat input always available
|
| 94 |
question = st.chat_input('Ask a question—general or document-specific...')
|
| 95 |
if question:
|
| 96 |
-
# Retrieve context
|
| 97 |
context = ''
|
| 98 |
if st.session_state.index is not None:
|
| 99 |
q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
|
| 100 |
_, idxs = st.session_state.index.search(q_emb, k=3)
|
| 101 |
context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0])
|
| 102 |
|
| 103 |
-
# Build prompt
|
| 104 |
system_prompt = (
|
| 105 |
"You are RagBot, an AI assistant. "
|
| 106 |
-
"
|
| 107 |
-
"
|
|
|
|
|
|
|
| 108 |
)
|
| 109 |
prompt = make_prompt(system_prompt, context, st.session_state.history, question)
|
| 110 |
|
| 111 |
# Generate answer
|
| 112 |
-
response = st.session_state.chat_gen(prompt, max_new_tokens=200
|
| 113 |
answer = response[0]['generated_text'].strip()
|
| 114 |
|
| 115 |
# Save & display
|
|
|
|
| 14 |
model='google/flan-t5-base',
|
| 15 |
tokenizer='google/flan-t5-base',
|
| 16 |
device=-1,
|
| 17 |
+
# enforce deterministic decoding and low temperature to reduce hallucinations
|
| 18 |
+
do_sample=False,
|
| 19 |
+
temperature=0.0,
|
| 20 |
)
|
| 21 |
return embedder, chat_gen
|
| 22 |
|
|
|
|
| 69 |
|
| 70 |
# Initialize state
|
| 71 |
if 'history' not in st.session_state:
|
| 72 |
+
st.session_state.history = []
|
| 73 |
if 'chunks' not in st.session_state:
|
| 74 |
st.session_state.chunks = []
|
| 75 |
if 'index' not in st.session_state:
|
|
|
|
| 83 |
st.session_state.embedder, st.session_state.chat_gen = load_resources()
|
| 84 |
st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
|
| 85 |
st.session_state.uploaded_name = uploaded.name
|
| 86 |
+
st.session_state.history = []
|
| 87 |
+
# Load models if missing
|
| 88 |
if 'embedder' not in st.session_state or 'chat_gen' not in st.session_state:
|
| 89 |
st.session_state.embedder, st.session_state.chat_gen = load_resources()
|
| 90 |
|
|
|
|
| 96 |
# Chat input always available
|
| 97 |
question = st.chat_input('Ask a question—general or document-specific...')
|
| 98 |
if question:
|
| 99 |
+
# Retrieve context
|
| 100 |
context = ''
|
| 101 |
if st.session_state.index is not None:
|
| 102 |
q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
|
| 103 |
_, idxs = st.session_state.index.search(q_emb, k=3)
|
| 104 |
context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0])
|
| 105 |
|
| 106 |
+
# Build prompt with hallucination guard
|
| 107 |
system_prompt = (
|
| 108 |
"You are RagBot, an AI assistant. "
|
| 109 |
+
"You must ONLY use the document context provided to answer document-specific questions. "
|
| 110 |
+
"If the answer is not contained in the context, respond with: "
|
| 111 |
+
"\"I’m sorry, I don’t know based on the document.\" "
|
| 112 |
+
"For general knowledge questions, answer using your training knowledge without hallucinating."
|
| 113 |
)
|
| 114 |
prompt = make_prompt(system_prompt, context, st.session_state.history, question)
|
| 115 |
|
| 116 |
# Generate answer
|
| 117 |
+
response = st.session_state.chat_gen(prompt, max_new_tokens=200)
|
| 118 |
answer = response[0]['generated_text'].strip()
|
| 119 |
|
| 120 |
# Save & display
|