sourize committed on
Commit
873decf
·
1 Parent(s): 13f2322

Updated main.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -14,6 +14,9 @@ def load_resources():
14
  model='google/flan-t5-base',
15
  tokenizer='google/flan-t5-base',
16
  device=-1,
 
 
 
17
  )
18
  return embedder, chat_gen
19
 
@@ -66,7 +69,7 @@ def main():
66
 
67
  # Initialize state
68
  if 'history' not in st.session_state:
69
- st.session_state.history = [] # list of {'role': 'User'|'Assistant', 'text': ...}
70
  if 'chunks' not in st.session_state:
71
  st.session_state.chunks = []
72
  if 'index' not in st.session_state:
@@ -80,8 +83,8 @@ def main():
80
  st.session_state.embedder, st.session_state.chat_gen = load_resources()
81
  st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
82
  st.session_state.uploaded_name = uploaded.name
83
- st.session_state.history = [] # reset conversation
84
- # Load models if not loaded
85
  if 'embedder' not in st.session_state or 'chat_gen' not in st.session_state:
86
  st.session_state.embedder, st.session_state.chat_gen = load_resources()
87
 
@@ -93,23 +96,25 @@ def main():
93
  # Chat input always available
94
  question = st.chat_input('Ask a question—general or document-specific...')
95
  if question:
96
- # Retrieve context if index exists
97
  context = ''
98
  if st.session_state.index is not None:
99
  q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
100
  _, idxs = st.session_state.index.search(q_emb, k=3)
101
  context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0])
102
 
103
- # Build prompt
104
  system_prompt = (
105
  "You are RagBot, an AI assistant. "
106
- "Use the provided document context for specific questions, "
107
- "and your general knowledge for everything else."
 
 
108
  )
109
  prompt = make_prompt(system_prompt, context, st.session_state.history, question)
110
 
111
  # Generate answer
112
- response = st.session_state.chat_gen(prompt, max_new_tokens=200, do_sample=False)
113
  answer = response[0]['generated_text'].strip()
114
 
115
  # Save & display
 
14
  model='google/flan-t5-base',
15
  tokenizer='google/flan-t5-base',
16
  device=-1,
17
+ # enforce deterministic decoding and low temperature to reduce hallucinations
18
+ do_sample=False,
19
+ temperature=0.0,
20
  )
21
  return embedder, chat_gen
22
 
 
69
 
70
  # Initialize state
71
  if 'history' not in st.session_state:
72
+ st.session_state.history = []
73
  if 'chunks' not in st.session_state:
74
  st.session_state.chunks = []
75
  if 'index' not in st.session_state:
 
83
  st.session_state.embedder, st.session_state.chat_gen = load_resources()
84
  st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
85
  st.session_state.uploaded_name = uploaded.name
86
+ st.session_state.history = []
87
+ # Load models if missing
88
  if 'embedder' not in st.session_state or 'chat_gen' not in st.session_state:
89
  st.session_state.embedder, st.session_state.chat_gen = load_resources()
90
 
 
96
  # Chat input always available
97
  question = st.chat_input('Ask a question—general or document-specific...')
98
  if question:
99
+ # Retrieve context
100
  context = ''
101
  if st.session_state.index is not None:
102
  q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
103
  _, idxs = st.session_state.index.search(q_emb, k=3)
104
  context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0])
105
 
106
+ # Build prompt with hallucination guard
107
  system_prompt = (
108
  "You are RagBot, an AI assistant. "
109
+ "You must ONLY use the document context provided to answer document-specific questions. "
110
+ "If the answer is not contained in the context, respond with: "
111
+ "\"I’m sorry, I don’t know based on the document.\" "
112
+ "For general knowledge questions, answer using your training knowledge without hallucinating."
113
  )
114
  prompt = make_prompt(system_prompt, context, st.session_state.history, question)
115
 
116
  # Generate answer
117
+ response = st.session_state.chat_gen(prompt, max_new_tokens=200)
118
  answer = response[0]['generated_text'].strip()
119
 
120
  # Save & display