decodingdatascience committed on
Commit
ef780a9
·
verified ·
1 Parent(s): 3a846c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -35
app.py CHANGED
@@ -41,18 +41,13 @@ FAQ_ITEMS = [
41
 
42
  LOGO_RAW_URL = "https://raw.githubusercontent.com/Decoding-Data-Science/airesidency/main/dds-logo-removebg-preview.png"
43
 
44
- # PDFs live in repo under ./data/pdfs
45
  PDF_DIR = Path("data/pdfs")
46
 
47
- # Use persistent disk if available
48
  PERSIST_ROOT = Path("/data") if Path("/data").exists() else Path(".")
49
  VDB_DIR = PERSIST_ROOT / "chroma"
50
 
51
- # Optional HF speed optimization when persistent disk exists
52
- # (HF docs mention setting HF_HOME to /data/.huggingface to speed restarts)
53
- if Path("/data").exists():
54
- os.environ.setdefault("HF_HOME", "/data/.huggingface")
55
-
56
  # -----------------------------
57
  # Helpers
58
  # -----------------------------
@@ -74,23 +69,21 @@ def download_logo() -> str | None:
74
  return None
75
 
76
  def build_or_load_index():
77
- # Guard: ensure OpenAI key exists
78
  if not os.getenv("OPENAI_API_KEY"):
79
  raise RuntimeError("OPENAI_API_KEY is not set. Add it in Space Settings → Repository secrets.")
80
 
81
  if not PDF_DIR.exists():
82
- raise RuntimeError(f"PDF folder not found: {PDF_DIR}. Add your PDFs under data/pdfs/ in the Space repo.")
83
 
84
  pdfs = sorted(PDF_DIR.glob("*.pdf"))
85
  if not pdfs:
86
- raise RuntimeError(f"No PDFs found in {PDF_DIR}. Upload your 4 HR PDFs there.")
87
 
88
  # LlamaIndex settings
89
  Settings.embed_model = OpenAIEmbedding(model=EMBED_MODEL)
90
  Settings.llm = LIOpenAI(model=LLM_MODEL, temperature=0.0)
91
  Settings.node_parser = SentenceSplitter(chunk_size=900, chunk_overlap=150)
92
 
93
- # Read documents
94
  docs = SimpleDirectoryReader(
95
  input_dir=str(PDF_DIR),
96
  required_exts=[".pdf"],
@@ -101,21 +94,22 @@ def build_or_load_index():
101
  VDB_DIR.mkdir(parents=True, exist_ok=True)
102
  chroma_client = chromadb.PersistentClient(path=str(VDB_DIR))
103
 
104
- # Reuse existing collection if present; otherwise create/build
105
  try:
106
  col = chroma_client.get_collection(COLLECTION_NAME)
107
- # If count works and >0, reuse
108
  try:
109
  if col.count() > 0:
110
  vector_store = ChromaVectorStore(chroma_collection=col)
111
  storage_context = StorageContext.from_defaults(vector_store=vector_store)
112
- return VectorStoreIndex.from_vector_store(vector_store=vector_store, storage_context=storage_context)
 
 
113
  except Exception:
114
  pass
115
  except Exception:
116
  pass
117
 
118
- # Create/build fresh
119
  try:
120
  chroma_client.delete_collection(COLLECTION_NAME)
121
  except Exception:
@@ -127,7 +121,23 @@ def build_or_load_index():
127
 
128
  return VectorStoreIndex.from_documents(docs, storage_context=storage_context)
129
 
130
- # Build index at startup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  INDEX = build_or_load_index()
132
 
133
  CHAT_ENGINE = INDEX.as_chat_engine(
@@ -136,7 +146,11 @@ CHAT_ENGINE = INDEX.as_chat_engine(
136
  system_prompt=SYSTEM_PROMPT,
137
  )
138
 
139
- def answer(user_msg: str, history: list[tuple[str, str]], show_sources: bool):
 
 
 
 
140
  user_msg = (user_msg or "").strip()
141
  if not user_msg:
142
  return history, ""
@@ -145,20 +159,13 @@ def answer(user_msg: str, history: list[tuple[str, str]], show_sources: bool):
145
  text = str(resp).strip()
146
 
147
  if show_sources:
148
- srcs = getattr(resp, "source_nodes", None) or []
149
- if srcs:
150
- lines = ["", "Sources:"]
151
- for i, sn in enumerate(srcs[:5], start=1):
152
- md = sn.node.metadata or {}
153
- doc = _md_get(md, ["file_name", "filename", "doc_name", "source"], "unknown_doc")
154
- page = _md_get(md, ["page_label", "page", "page_number"], "?")
155
- score = sn.score if sn.score is not None else float("nan")
156
- lines.append(f"{i}) {doc} | page {page} | score {score:.3f}")
157
- text = text + "\n" + "\n".join(lines)
158
- else:
159
- text = text + "\n\nSources: (none returned)"
160
-
161
- history = history + [(user_msg, text)]
162
  return history, ""
163
 
164
  def load_faq(faq_choice: str):
@@ -168,7 +175,7 @@ def clear_chat():
168
  return [], ""
169
 
170
  # -----------------------------
171
- # Gradio UI
172
  # -----------------------------
173
  logo_path = download_logo()
174
 
@@ -178,8 +185,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
178
  gr.Image(value=logo_path, show_label=False, height=70, width=70, container=False)
179
  gr.Markdown(
180
  "# DDS HR Chatbot (RAG Demo)\n"
181
- "Ask HR policy questions. The assistant answers **only from the provided DDS policy PDFs** "
182
- "and can show sources."
183
  )
184
 
185
  with gr.Row():
@@ -193,7 +199,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
193
  clear_btn = gr.Button("Clear chat")
194
 
195
  with gr.Column(scale=2, min_width=520):
196
- chatbot = gr.Chatbot(label="DDS HR Assistant", height=520)
 
197
  user_input = gr.Textbox(label="Your question", placeholder="Ask a policy question and press Enter")
198
  send_btn = gr.Button("Send")
199
 
 
41
 
42
  LOGO_RAW_URL = "https://raw.githubusercontent.com/Decoding-Data-Science/airesidency/main/dds-logo-removebg-preview.png"
43
 
44
+ # PDFs in repo
45
  PDF_DIR = Path("data/pdfs")
46
 
47
+ # Persistent disk if enabled on Spaces
48
  PERSIST_ROOT = Path("/data") if Path("/data").exists() else Path(".")
49
  VDB_DIR = PERSIST_ROOT / "chroma"
50
 
 
 
 
 
 
51
  # -----------------------------
52
  # Helpers
53
  # -----------------------------
 
69
  return None
70
 
71
  def build_or_load_index():
 
72
  if not os.getenv("OPENAI_API_KEY"):
73
  raise RuntimeError("OPENAI_API_KEY is not set. Add it in Space Settings → Repository secrets.")
74
 
75
  if not PDF_DIR.exists():
76
+ raise RuntimeError(f"PDF folder not found: {PDF_DIR}. Add PDFs under data/pdfs/.")
77
 
78
  pdfs = sorted(PDF_DIR.glob("*.pdf"))
79
  if not pdfs:
80
+ raise RuntimeError(f"No PDFs found in {PDF_DIR}. Upload your HR PDFs there.")
81
 
82
  # LlamaIndex settings
83
  Settings.embed_model = OpenAIEmbedding(model=EMBED_MODEL)
84
  Settings.llm = LIOpenAI(model=LLM_MODEL, temperature=0.0)
85
  Settings.node_parser = SentenceSplitter(chunk_size=900, chunk_overlap=150)
86
 
 
87
  docs = SimpleDirectoryReader(
88
  input_dir=str(PDF_DIR),
89
  required_exts=[".pdf"],
 
94
  VDB_DIR.mkdir(parents=True, exist_ok=True)
95
  chroma_client = chromadb.PersistentClient(path=str(VDB_DIR))
96
 
97
+ # Reuse existing collection if it already has vectors
98
  try:
99
  col = chroma_client.get_collection(COLLECTION_NAME)
 
100
  try:
101
  if col.count() > 0:
102
  vector_store = ChromaVectorStore(chroma_collection=col)
103
  storage_context = StorageContext.from_defaults(vector_store=vector_store)
104
+ return VectorStoreIndex.from_vector_store(
105
+ vector_store=vector_store, storage_context=storage_context
106
+ )
107
  except Exception:
108
  pass
109
  except Exception:
110
  pass
111
 
112
+ # Build fresh collection
113
  try:
114
  chroma_client.delete_collection(COLLECTION_NAME)
115
  except Exception:
 
121
 
122
  return VectorStoreIndex.from_documents(docs, storage_context=storage_context)
123
 
124
+ def format_sources(resp, max_sources=5) -> str:
125
+ srcs = getattr(resp, "source_nodes", None) or []
126
+ if not srcs:
127
+ return "Sources: (none returned)"
128
+
129
+ lines = ["Sources:"]
130
+ for i, sn in enumerate(srcs[:max_sources], start=1):
131
+ md = sn.node.metadata or {}
132
+ doc = _md_get(md, ["file_name", "filename", "doc_name", "source"], "unknown_doc")
133
+ page = _md_get(md, ["page_label", "page", "page_number"], "?")
134
+ score = sn.score if sn.score is not None else float("nan")
135
+ lines.append(f"{i}) {doc} | page {page} | score {score:.3f}")
136
+ return "\n".join(lines)
137
+
138
+ # -----------------------------
139
+ # Build index + chat engine
140
+ # -----------------------------
141
  INDEX = build_or_load_index()
142
 
143
  CHAT_ENGINE = INDEX.as_chat_engine(
 
146
  system_prompt=SYSTEM_PROMPT,
147
  )
148
 
149
+ # -----------------------------
150
+ # Gradio callbacks (MESSAGES format)
151
+ # history is: [{"role":"user","content":"..."}, {"role":"assistant","content":"..."}, ...]
152
+ # -----------------------------
153
+ def answer(user_msg: str, history: list, show_sources: bool):
154
  user_msg = (user_msg or "").strip()
155
  if not user_msg:
156
  return history, ""
 
159
  text = str(resp).strip()
160
 
161
  if show_sources:
162
+ text = text + "\n\n" + format_sources(resp)
163
+
164
+ # Append messages (this fixes your error)
165
+ history = (history or []) + [
166
+ {"role": "user", "content": user_msg},
167
+ {"role": "assistant", "content": text},
168
+ ]
 
 
 
 
 
 
 
169
  return history, ""
170
 
171
  def load_faq(faq_choice: str):
 
175
  return [], ""
176
 
177
  # -----------------------------
178
+ # UI
179
  # -----------------------------
180
  logo_path = download_logo()
181
 
 
185
  gr.Image(value=logo_path, show_label=False, height=70, width=70, container=False)
186
  gr.Markdown(
187
  "# DDS HR Chatbot (RAG Demo)\n"
188
+ "Ask HR policy questions. The assistant answers **only from the DDS HR PDFs** and can show sources."
 
189
  )
190
 
191
  with gr.Row():
 
199
  clear_btn = gr.Button("Clear chat")
200
 
201
  with gr.Column(scale=2, min_width=520):
202
+ # IMPORTANT: type="messages"
203
+ chatbot = gr.Chatbot(label="DDS HR Assistant", height=520, type="messages")
204
  user_input = gr.Textbox(label="Your question", placeholder="Ask a policy question and press Enter")
205
  send_btn = gr.Button("Send")
206