fsojni committed on
Commit
f1365f3
·
verified ·
1 Parent(s): 5dce132

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -18
app.py CHANGED
@@ -19,6 +19,18 @@ CHAT_MODEL_ID = "NousResearch/Meta-Llama-3-8B-Instruct"
19
  EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
20
  MAX_PROMPT_TOKENS = 8192
21
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # --- lazy loaders (unchanged) -------------------------------------------------
23
  tokenizer, chat_model = None, None
24
  emb_tokenizer, emb_model = None, None
@@ -60,13 +72,13 @@ def embed(text:str)->torch.Tensor:
60
 
61
  kb = defaultdict(lambda: {"texts": [], "vecs": None})
62
 
63
- def add_docs(user_id: str, docs: list[str]) -> int:
64
- """Embed *docs* and append them to the KB for *user_id*.
65
- Returns the number of docs actually stored."""
66
- docs = [t for t in docs if t.strip()] # skip blanks
67
- if not docs:
68
- return 0
69
 
 
 
 
 
 
70
  load_embedder() # lazy-load once
71
  new_vecs = torch.stack([embed(t) for t in docs]).cpu()
72
  store = kb[user_id] # auto-creates via defaultdict
@@ -117,10 +129,9 @@ def build_llm_prompt(system: str, context: list[str], user_question: str) -> str
117
  return prompt
118
 
119
  # ---------- 4. Gradio playground (same UI as before) --------------------------
120
- def store_doc(doc_text: str, user_id="demo"):
121
- """UI callback: take the textbox content and shove it into the KB."""
122
  try:
123
- n = add_docs(user_id, [doc_text])
124
  if n == 0:
125
  return "Nothing stored (empty input)."
126
  return f"Stored — KB now has {len(kb[user_id]['texts'])} passage(s)."
@@ -128,7 +139,11 @@ def store_doc(doc_text: str, user_id="demo"):
128
  return f"Error during storing: {e}"
129
 
130
  import traceback
131
- def answer(system: str, context: str, question: str, user_id="demo", history="None"):
 
 
 
 
132
  """UI callback: retrieve, build prompt with Qwen tags, generate answer."""
133
  try:
134
  if not question.strip():
@@ -175,6 +190,10 @@ def answer(system: str, context: str, question: str, user_id="demo", history="No
175
  **tokens,
176
  max_new_tokens=512,
177
  max_length=MAX_PROMPT_TOKENS + 512,
 
 
 
 
178
  )
179
  full = tokenizer.decode(output[0], skip_special_tokens=True)
180
  reply = full.split("<|im_start|>assistant")[-1].strip()
@@ -203,7 +222,13 @@ with gr.Blocks() as demo:
203
  with gr.Row():
204
  passage_box = gr.Textbox(lines=6, label="Reference passage")
205
  user_id_box = gr.Textbox(value="demo", label="User ID")
206
- store_btn = gr.Button("Store passage")
 
 
 
 
 
 
207
  clear_btn = gr.Button("Clear KB")
208
 
209
  status_box = gr.Markdown()
@@ -216,16 +241,27 @@ with gr.Blocks() as demo:
216
 
217
  # ---- Q & A ----
218
  question_box = gr.Textbox(lines=2, label="Ask a question")
219
- history_cb = gr.Textbox(value="None", label="Use chat history")
220
- system_box = gr.Textbox(lines=2, label="System prompt")
221
- context_box = gr.Textbox(lines=6, label="Context passages (each line one passage)")
222
-
223
- answer_btn = gr.Button("Answer")
224
- answer_box = gr.Textbox(lines=6, label="Assistant reply")
 
 
 
 
 
 
 
 
 
225
 
226
  answer_btn.click(
227
  fn=answer,
228
- inputs=[system_box, context_box, question_box, user_id_box, history_cb],
 
 
229
  outputs=answer_box
230
  )
231
 
 
19
  EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
20
  MAX_PROMPT_TOKENS = 8192
21
 
22
+ # ---------- new defaults & helper ------------------
23
+ DEFAULT_TEMP = 0.7
24
+ DEFAULT_TOP_P = 0.9
25
+ DEFAULT_TOP_K_TOK = 40 # token-level sampling
26
+ DEFAULT_CHUNK_SIZE = 512 # characters
27
+ DEFAULT_CHUNK_OVERLAP = 128
28
+
29
def chunk_text(text: str, size: int, overlap: int):
    """Yield sliding-window chunks of *text* with character overlap.

    Args:
        text: The string to split.
        size: Maximum chunk length in characters.
        overlap: Number of characters shared between consecutive chunks.

    Yields:
        Successive substrings of *text*, each at most *size* characters long,
        starting every ``max(1, size - overlap)`` characters.

    Note:
        The stride is clamped to at least 1 so that ``overlap >= size``
        (reachable from the UI sliders, where overlap can exceed chunk size)
        cannot produce ``range(..., step=0)`` — which raises ValueError —
        or a negative step, which would silently yield nothing.
    """
    # Clamp to avoid a zero/negative range step when overlap >= size.
    stride = max(1, size - overlap)
    for start in range(0, len(text), stride):
        yield text[start : start + size]
33
+
34
  # --- lazy loaders (unchanged) -------------------------------------------------
35
  tokenizer, chat_model = None, None
36
  emb_tokenizer, emb_model = None, None
 
72
 
73
  kb = defaultdict(lambda: {"texts": [], "vecs": None})
74
 
75
+ def add_docs(user_id: str,docs: list[str],chunk_size: int = DEFAULT_CHUNK_SIZE,chunk_overlap: int = DEFAULT_CHUNK_OVERLAP) -> int:
 
 
 
 
 
76
 
77
+ # ---------- NEW ----------
78
+ chunks = []
79
+ for d in docs:
80
+ chunks.extend(chunk_text(d, chunk_size, chunk_overlap))
81
+ docs = [c for c in chunks if c.strip()]
82
  load_embedder() # lazy-load once
83
  new_vecs = torch.stack([embed(t) for t in docs]).cpu()
84
  store = kb[user_id] # auto-creates via defaultdict
 
129
  return prompt
130
 
131
  # ---------- 4. Gradio playground (same UI as before) --------------------------
132
+ def store_doc(doc_text: str,user_id="demo",chunk_size=DEFAULT_CHUNK_SIZE,chunk_overlap=DEFAULT_CHUNK_OVERLAP):
 
133
  try:
134
+ n = add_docs(user_id, [doc_text], chunk_size, chunk_overlap)
135
  if n == 0:
136
  return "Nothing stored (empty input)."
137
  return f"Stored — KB now has {len(kb[user_id]['texts'])} passage(s)."
 
139
  return f"Error during storing: {e}"
140
 
141
  import traceback
142
+ def answer(system: str, context: str, question: str,
143
+ user_id="demo", history="None",
144
+ temperature=DEFAULT_TEMP,
145
+ top_p=DEFAULT_TOP_P,
146
+ top_k_tok=DEFAULT_TOP_K_TOK):
147
  """UI callback: retrieve, build prompt with Qwen tags, generate answer."""
148
  try:
149
  if not question.strip():
 
190
  **tokens,
191
  max_new_tokens=512,
192
  max_length=MAX_PROMPT_TOKENS + 512,
193
+ do_sample=True,
194
+ temperature=temperature,
195
+ top_p=top_p,
196
+ top_k=top_k_tok
197
  )
198
  full = tokenizer.decode(output[0], skip_special_tokens=True)
199
  reply = full.split("<|im_start|>assistant")[-1].strip()
 
222
  with gr.Row():
223
  passage_box = gr.Textbox(lines=6, label="Reference passage")
224
  user_id_box = gr.Textbox(value="demo", label="User ID")
225
+ chunk_box = gr.Slider(128, 2048, value=DEFAULT_CHUNK_SIZE,
226
+ step=64, label="Chunk size (chars)")
227
+ overlap_box = gr.Slider(0, 1024, value=DEFAULT_CHUNK_OVERLAP,
228
+ step=32, label="Chunk overlap")
229
+ store_btn.click(fn=store_doc,
230
+ inputs=[passage_box, user_id_box, chunk_box, overlap_box],
231
+ outputs=status_box)
232
  clear_btn = gr.Button("Clear KB")
233
 
234
  status_box = gr.Markdown()
 
241
 
242
  # ---- Q & A ----
243
  question_box = gr.Textbox(lines=2, label="Ask a question")
244
+ history_cb = gr.Textbox(value="None", label="Use chat history")
245
+ system_box = gr.Textbox(lines=2, label="System prompt")
246
+ context_box = gr.Textbox(lines=6, label="Context passages")
247
+
248
+ # ---------- NEW sampling sliders ----------
249
+ temp_box = gr.Slider(0.0, 1.5, value=DEFAULT_TEMP,
250
+ step=0.05, label="Temperature")
251
+ topp_box = gr.Slider(0.0, 1.0, value=DEFAULT_TOP_P,
252
+ step=0.01, label="Top-p")
253
+ topk_box = gr.Slider(1, 100, value=DEFAULT_TOP_K_TOK,
254
+ step=1, label="Top-k (tokens)")
255
+ # ---------- /NEW ----------
256
+
257
+ answer_btn = gr.Button("Answer")
258
+ answer_box = gr.Textbox(lines=6, label="Assistant reply")
259
 
260
  answer_btn.click(
261
  fn=answer,
262
+ inputs=[system_box, context_box, question_box,
263
+ user_id_box, history_cb,
264
+ temp_box, topp_box, topk_box],
265
  outputs=answer_box
266
  )
267