jfang commited on
Commit
3718631
·
verified ·
1 Parent(s): 8a7f087

Upload 7 files

Browse files
app.py CHANGED
@@ -3,7 +3,13 @@ import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import spaces
5
  import re
6
- from typing import List, Dict, Tuple
 
 
 
 
 
 
7
 
8
 
9
  # Initialize model and tokenizer
@@ -21,6 +27,51 @@ model = AutoModelForCausalLM.from_pretrained(
21
  trust_remote_code=True
22
  )
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  @spaces.GPU(duration=60)
26
  def generate_response_stream(
@@ -191,6 +242,84 @@ def generate_response_stream(
191
  thread.join()
192
 
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  def respond(
195
  message: str,
196
  history: List[Dict[str, str]],
@@ -200,39 +329,186 @@ def respond(
200
  top_p: float,
201
  ):
202
  """
203
- Response function for custom Gradio interface with separate thinking display.
204
  """
205
- thinking_content = ""
206
- response_content = ""
 
 
207
 
208
  try:
209
- # Stream tokens from the model
 
 
 
 
 
 
 
 
 
210
  for thinking, response in generate_response_stream(
211
  message=message,
212
  history=history,
213
- system_message=system_message,
214
  max_tokens=max_tokens,
215
  temperature=temperature,
216
  top_p=top_p,
217
  ):
218
- thinking_content = thinking
219
- response_content = response
220
- # Yield both thinking and response content
221
- yield thinking_content, response_content
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  except Exception as e:
224
  error_message = f"❌ Error generating response: {str(e)}"
225
- yield "", error_message
226
 
227
 
228
  # Default system prompt for gprMax assistance
229
- DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant specialized in gprMax, an open-source software that simulates electromagnetic wave propagation. You help users with:
 
 
 
230
  1. Creating gprMax input files (.in files)
231
  2. Understanding gprMax commands and syntax
232
  3. Setting up simulations for GPR (Ground Penetrating Radar) and other EM applications
233
  4. Troubleshooting simulation issues
234
  5. Optimizing simulation parameters
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  If you give code blocks, ensure to enclose them inside ```.
237
 
238
  There is no need to always give full input codes, be sure to understand what user needs and intends to do. Some times a simple line of code can do, sometimes user wants explanation rather than codes.
@@ -297,13 +573,21 @@ with gr.Blocks(title="gprMax Support", theme=gr.themes.Ocean()) as demo:
297
  thinking_display = gr.Markdown(
298
  value="*Thinking process will appear here when the AI is reasoning through your question...*",
299
  label="Thinking",
300
- height=400,
 
 
 
 
 
 
 
 
301
  )
302
 
303
  # Settings
304
  with gr.Accordion("⚙️ Settings", open=True):
305
  system_message = gr.Textbox(
306
- value=DEFAULT_SYSTEM_PROMPT,
307
  label="System Message",
308
  lines=5,
309
  info="Customize the assistant's behavior"
@@ -345,7 +629,7 @@ with gr.Blocks(title="gprMax Support", theme=gr.themes.Ocean()) as demo:
345
 
346
  def bot_respond(history, system_msg, max_tok, temp, top_p_val):
347
  if not history or history[-1]["role"] != "user":
348
- yield history, "*No thinking process*"
349
  return
350
 
351
  user_message = history[-1]["content"]
@@ -355,10 +639,12 @@ with gr.Blocks(title="gprMax Support", theme=gr.themes.Ocean()) as demo:
355
  history = history + [{"role": "assistant", "content": ""}]
356
 
357
  thinking_text = ""
 
358
  is_thinking = False
359
  has_main_content = False
 
360
 
361
- for thinking, response in respond(
362
  user_message,
363
  history_for_model,
364
  system_msg,
@@ -368,46 +654,76 @@ with gr.Blocks(title="gprMax Support", theme=gr.themes.Ocean()) as demo:
368
  ):
369
  # Update thinking display
370
  if thinking:
371
- thinking_text = f"## Reasoning Process\n\n{thinking}"
372
- is_thinking = True
373
- else:
 
 
 
 
 
 
 
374
  thinking_text = "*Waiting for response...*"
375
 
 
 
 
 
376
  # Update chat response
377
  if response and response.strip():
378
  # We have actual response content
379
- history[-1]["content"] = response
380
- has_main_content = True
 
 
 
 
 
381
  elif is_thinking and not has_main_content:
382
  # Still thinking, no main response yet
383
  history[-1]["content"] = "🤔 *AI is thinking... Check the right pane for thinking details*"
 
 
384
  elif not response:
385
  # No response yet and no thinking detected
386
  history[-1]["content"] = "⏳ *Generating response...*"
387
 
388
- yield history, thinking_text
389
 
390
  # Event handlers
391
  msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
392
  bot_respond,
393
  [chatbot, system_message, max_tokens, temperature, top_p],
394
- [chatbot, thinking_display]
395
  )
396
 
397
  submit_btn.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
398
  bot_respond,
399
  [chatbot, system_message, max_tokens, temperature, top_p],
400
- [chatbot, thinking_display]
 
 
 
 
 
 
 
 
 
401
  )
402
 
403
- clear_btn.click(lambda: ([], "*Thinking process will appear here when the AI is reasoning through your question...*"), outputs=[chatbot, thinking_display])
 
404
 
405
  gr.Markdown(
406
- """
407
  ---
408
  ### About
409
  This assistant uses `jfang/gprmax-ft-Qwen3-4B-Instruct`, a model fine-tuned specifically for gprMax support.
410
 
 
 
411
  **Note**: For best results, be specific about your gprMax version and simulation requirements.
412
  """
413
  )
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import spaces
5
  import re
6
+ from typing import List, Dict, Tuple, Optional
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ # Add rag-db to path for imports
11
+ sys.path.append(str(Path(__file__).parent / "rag-db"))
12
+ from retriever import create_retriever, GprMaxRAGRetriever
13
 
14
 
15
  # Initialize model and tokenizer
 
27
  trust_remote_code=True
28
  )
29
 
30
+ # Initialize RAG retriever
31
+ RAG_DB_PATH = Path(__file__).parent / "rag-db" / "chroma_db"
32
+ retriever: Optional[GprMaxRAGRetriever] = None
33
+
34
+ def generate_database_if_needed():
35
+ """Generate the RAG database if it doesn't exist"""
36
+ if not RAG_DB_PATH.exists():
37
+ print("=" * 60)
38
+ print("RAG database not found. Generating database...")
39
+ print("This is a one-time process and may take a few minutes.")
40
+ print("=" * 60)
41
+
42
+ import subprocess
43
+ try:
44
+ # Run the generation script
45
+ result = subprocess.run(
46
+ ["python", str(Path(__file__).parent / "rag-db" / "generate_db.py")],
47
+ capture_output=True,
48
+ text=True,
49
+ check=True
50
+ )
51
+ print(result.stdout)
52
+ print("✅ Database generated successfully!")
53
+ return True
54
+ except subprocess.CalledProcessError as e:
55
+ print(f"❌ Failed to generate database: {e}")
56
+ if e.stderr:
57
+ print(f"Error output: {e.stderr}")
58
+ return False
59
+ return True
60
+
61
+ # Generate database if needed and load retriever
62
+ if generate_database_if_needed():
63
+ try:
64
+ print(f"Loading RAG database from {RAG_DB_PATH}")
65
+ retriever = create_retriever(db_path=RAG_DB_PATH)
66
+ print("RAG database loaded successfully")
67
+ except Exception as e:
68
+ print(f"Error loading RAG database: {e}")
69
+ print("RAG features will be disabled.")
70
+ retriever = None
71
+ else:
72
+ print("RAG features will be disabled due to database generation failure.")
73
+ retriever = None
74
+
75
 
76
  @spaces.GPU(duration=60)
77
  def generate_response_stream(
 
242
  thread.join()
243
 
244
 
245
+ # Tool definitions in Qwen3 format
246
+ TOOLS = [
247
+ {
248
+ "type": "function",
249
+ "function": {
250
+ "name": "search_documentation",
251
+ "description": "Search gprMax documentation for relevant information about commands, syntax, parameters, or usage",
252
+ "parameters": {
253
+ "type": "object",
254
+ "properties": {
255
+ "query": {
256
+ "type": "string",
257
+ "description": "The search query to find relevant documentation"
258
+ },
259
+ "num_results": {
260
+ "type": "integer",
261
+ "description": "Number of results to return",
262
+ "default": 10
263
+ }
264
+ },
265
+ "required": ["query"]
266
+ }
267
+ }
268
+ }
269
+ ]
270
+
271
+ def format_tools_prompt() -> str:
272
+ """Format tools for inclusion in system prompt"""
273
+ import json
274
+ return json.dumps(TOOLS, indent=2)
275
+
276
+
277
+ def perform_rag_search(query: str, k: int = 10) -> Tuple[str, List[Dict]]:
278
+ """
279
+ Perform RAG search and return formatted context and sources
280
+
281
+ Returns:
282
+ Tuple of (context_for_llm, source_list_for_display)
283
+ """
284
+ if not retriever:
285
+ print(f"[DEBUG] Retriever is None!")
286
+ return "", []
287
+
288
+ try:
289
+ print(f"[DEBUG] Searching for: '{query}' with k={k}")
290
+ # Search for relevant documents
291
+ results = retriever.search(query, k=k)
292
+
293
+ print(f"[DEBUG] Search returned {len(results) if results else 0} results")
294
+ if not results:
295
+ return "", []
296
+
297
+ # Format context for LLM - pass all text content
298
+ context_parts = []
299
+ source_list = []
300
+
301
+ for i, result in enumerate(results, 1):
302
+ # Add full text to context for LLM (up to 1000 chars per doc)
303
+ context_parts.append(f"[Document {i}]: {result.text}")
304
+
305
+ # Add to source list for display (limited preview)
306
+ source_list.append({
307
+ "index": i,
308
+ "source": result.metadata.get("source", "Unknown"),
309
+ "score": result.score,
310
+ "preview": result.text[:150] + "..." if len(result.text) > 150 else result.text
311
+ })
312
+
313
+ context = "\n\n".join(context_parts)
314
+ return context, source_list
315
+
316
+ except Exception as e:
317
+ print(f"[DEBUG] RAG search error: {e}")
318
+ import traceback
319
+ traceback.print_exc()
320
+ return "", []
321
+
322
+
323
  def respond(
324
  message: str,
325
  history: List[Dict[str, str]],
 
329
  top_p: float,
330
  ):
331
  """
332
+ Response function with proper Qwen3 tool calling
333
  """
334
+ import json
335
+ import re
336
+
337
+ sources_content = ""
338
 
339
  try:
340
+ # Use system message as-is (already has tools included)
341
+ system_with_tools = system_message
342
+
343
+ # First, get initial response from model to see if it wants to use tools
344
+ tool_call = None
345
+ accumulated_response = ""
346
+ final_thinking = ""
347
+ is_complete = False
348
+
349
+ # Collect the full response (thinking + potential tool call)
350
  for thinking, response in generate_response_stream(
351
  message=message,
352
  history=history,
353
+ system_message=system_with_tools,
354
  max_tokens=max_tokens,
355
  temperature=temperature,
356
  top_p=top_p,
357
  ):
358
+ final_thinking = thinking if thinking else final_thinking
359
+ accumulated_response = response
 
 
360
 
361
+ # Show thinking progress only
362
+ if thinking:
363
+ yield thinking, "⏳ *AI is analyzing your request...*", sources_content
364
+
365
+ # After streaming completes, check what we got
366
+ if accumulated_response and accumulated_response.strip():
367
+ # Check if the complete response is a JSON tool call
368
+ if accumulated_response.strip().startswith('{'):
369
+ try:
370
+ # Try to parse the entire response as JSON
371
+ response_json = json.loads(accumulated_response.strip())
372
+ if "tool_call" in response_json or ("thought" in response_json and "tool_call" in response_json):
373
+ tool_call = response_json.get("tool_call") or response_json["tool_call"]
374
+ # Show status that we're processing the tool call
375
+ yield final_thinking, "🔍 *Processing documentation search request...*", sources_content
376
+ is_complete = True
377
+ except json.JSONDecodeError:
378
+ # Invalid JSON, treat as normal response
379
+ yield final_thinking, accumulated_response, sources_content
380
+ is_complete = True
381
+ except Exception:
382
+ yield final_thinking, accumulated_response, sources_content
383
+ is_complete = True
384
+ else:
385
+ # It's a normal text response, not a tool call
386
+ yield final_thinking, accumulated_response, sources_content
387
+ is_complete = True
388
+
389
+ # If tool was called, execute it
390
+ if tool_call and retriever:
391
+ tool_name = tool_call.get("name")
392
+ print(f"[DEBUG] Tool called: {tool_name}")
393
+ print(f"[DEBUG] Tool call details: {tool_call}")
394
+
395
+ if tool_name == "search_documentation":
396
+ # Update status
397
+ yield "🔍 *Searching documentation...*", "⏳ *Preparing to search...*", "📚 *Retrieving relevant documents...*"
398
+
399
+ # Get search query
400
+ query = tool_call.get("arguments", {}).get("query", message)
401
+ num_results = tool_call.get("arguments", {}).get("num_results", 10)
402
+ print(f"[DEBUG] Query extracted: '{query}', num_results: {num_results}")
403
+
404
+ # Perform search
405
+ context, sources_list = perform_rag_search(query, k=num_results)
406
+ print(f"[DEBUG] Search results - Context length: {len(context)}, Sources: {len(sources_list)}")
407
+
408
+ if context:
409
+ # Format sources for display
410
+ if sources_list:
411
+ sources_parts = ["## 📚 Documentation Sources\n"]
412
+ for source in sources_list:
413
+ sources_parts.append(
414
+ f"**[{source['index']}] {source['source']}** (Score: {source['score']:.3f})\n"
415
+ f"```\n{source['preview']}\n```\n"
416
+ )
417
+ sources_content = "\n".join(sources_parts)
418
+ else:
419
+ sources_content = "*No relevant documentation found*"
420
+
421
+ yield "✅ *Documentation retrieved*", "⏳ *Generating response with context...*", sources_content
422
+
423
+ # Now generate response with the retrieved context
424
+ augmented_message = f"""Tool call result for search_documentation:
425
+
426
+ {context}
427
+
428
+ Original question: {message}
429
+
430
+ Please provide a comprehensive answer based on the documentation above."""
431
+
432
+ # Generate final response with context
433
+ for thinking, response in generate_response_stream(
434
+ message=augmented_message,
435
+ history=history,
436
+ system_message=system_message, # Use original system message for final response
437
+ max_tokens=max_tokens,
438
+ temperature=temperature,
439
+ top_p=top_p,
440
+ ):
441
+ yield thinking, response, sources_content
442
+ else:
443
+ sources_content = "*No relevant documentation found*"
444
+ yield final_thinking, "⚠️ *Unable to retrieve documentation. Providing general answer...*", sources_content
445
+
446
+ # Generate response without documentation context
447
+ fallback_message = f"""The user asked about: {message}
448
+
449
+ No relevant documentation was found in the database. Please provide a helpful answer based on your general knowledge of gprMax."""
450
+
451
+ for thinking, response in generate_response_stream(
452
+ message=fallback_message,
453
+ history=history,
454
+ system_message=system_message,
455
+ max_tokens=max_tokens,
456
+ temperature=temperature,
457
+ top_p=top_p,
458
+ ):
459
+ yield thinking, response, sources_content
460
+ # If tool was called but retriever is not available
461
+ elif tool_call and not retriever:
462
+ yield final_thinking, "⚠️ *Documentation search is not available. Providing answer based on general knowledge...*", ""
463
+
464
+ # Generate response without RAG
465
+ for thinking, response in generate_response_stream(
466
+ message=message,
467
+ history=history,
468
+ system_message=system_message,
469
+ max_tokens=max_tokens,
470
+ temperature=temperature,
471
+ top_p=top_p,
472
+ ):
473
+ yield thinking, response, ""
474
+ # If no tool call and response wasn't already yielded
475
+ elif not tool_call and not is_complete:
476
+ # This shouldn't happen but handle it just in case
477
+ if accumulated_response and not accumulated_response.strip().startswith('{'):
478
+ yield final_thinking, accumulated_response, sources_content
479
+
480
  except Exception as e:
481
  error_message = f"❌ Error generating response: {str(e)}"
482
+ yield "", error_message, ""
483
 
484
 
485
  # Default system prompt for gprMax assistance
486
+ def get_default_system_prompt():
487
+ """Get system prompt with tools formatted"""
488
+ tools_json = format_tools_prompt()
489
+ return f"""You are a helpful assistant specialized in gprMax, an open-source software that simulates electromagnetic wave propagation. You help users with:
490
  1. Creating gprMax input files (.in files)
491
  2. Understanding gprMax commands and syntax
492
  3. Setting up simulations for GPR (Ground Penetrating Radar) and other EM applications
493
  4. Troubleshooting simulation issues
494
  5. Optimizing simulation parameters
495
 
496
+ You have access to the following tools:
497
+ {tools_json}
498
+
499
+ When you need to search documentation, respond with a tool call in this JSON format:
500
+ {{
501
+ "thought": "I need to search the documentation for...",
502
+ "tool_call": {{
503
+ "name": "search_documentation",
504
+ "arguments": {{
505
+ "query": "your search query here"
506
+ }}
507
+ }}
508
+ }}
509
+
510
+ After receiving tool results, provide a comprehensive answer based on the documentation.
511
+
512
  If you give code blocks, ensure to enclose them inside ```.
513
 
514
  There is no need to always give full input codes, be sure to understand what user needs and intends to do. Some times a simple line of code can do, sometimes user wants explanation rather than codes.
 
573
  thinking_display = gr.Markdown(
574
  value="*Thinking process will appear here when the AI is reasoning through your question...*",
575
  label="Thinking",
576
+ height=300,
577
+ )
578
+
579
+ # Documentation sources in collapsible accordion
580
+ with gr.Accordion("📚 Documentation Sources", open=False) as sources_accordion:
581
+ sources_display = gr.Markdown(
582
+ value="*Documentation sources will appear here when RAG search is performed...*",
583
+ label="Sources",
584
+ height=300,
585
  )
586
 
587
  # Settings
588
  with gr.Accordion("⚙️ Settings", open=True):
589
  system_message = gr.Textbox(
590
+ value=get_default_system_prompt(),
591
  label="System Message",
592
  lines=5,
593
  info="Customize the assistant's behavior"
 
629
 
630
  def bot_respond(history, system_msg, max_tok, temp, top_p_val):
631
  if not history or history[-1]["role"] != "user":
632
+ yield history, "*No thinking process*", "*No sources*"
633
  return
634
 
635
  user_message = history[-1]["content"]
 
639
  history = history + [{"role": "assistant", "content": ""}]
640
 
641
  thinking_text = ""
642
+ sources_text = ""
643
  is_thinking = False
644
  has_main_content = False
645
+ is_searching = False
646
 
647
+ for thinking, response, sources in respond(
648
  user_message,
649
  history_for_model,
650
  system_msg,
 
654
  ):
655
  # Update thinking display
656
  if thinking:
657
+ if "Searching documentation" in thinking:
658
+ thinking_text = thinking
659
+ is_searching = True
660
+ elif "Documentation retrieved" in thinking:
661
+ thinking_text = thinking
662
+ is_searching = False
663
+ else:
664
+ thinking_text = f"## Reasoning Process\n\n{thinking}"
665
+ is_thinking = True
666
+ elif not thinking and not is_searching:
667
  thinking_text = "*Waiting for response...*"
668
 
669
+ # Update sources display
670
+ if sources:
671
+ sources_text = sources
672
+
673
  # Update chat response
674
  if response and response.strip():
675
  # We have actual response content
676
+ if "Preparing to search" in response or "Generating response" in response:
677
+ # Status messages
678
+ history[-1]["content"] = response
679
+ else:
680
+ # Actual content
681
+ history[-1]["content"] = response
682
+ has_main_content = True
683
  elif is_thinking and not has_main_content:
684
  # Still thinking, no main response yet
685
  history[-1]["content"] = "🤔 *AI is thinking... Check the right pane for thinking details*"
686
+ elif is_searching:
687
+ history[-1]["content"] = "🔍 *Searching documentation...*"
688
  elif not response:
689
  # No response yet and no thinking detected
690
  history[-1]["content"] = "⏳ *Generating response...*"
691
 
692
+ yield history, thinking_text, sources_text
693
 
694
  # Event handlers
695
  msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
696
  bot_respond,
697
  [chatbot, system_message, max_tokens, temperature, top_p],
698
+ [chatbot, thinking_display, sources_display]
699
  )
700
 
701
  submit_btn.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
702
  bot_respond,
703
  [chatbot, system_message, max_tokens, temperature, top_p],
704
+ [chatbot, thinking_display, sources_display]
705
+ )
706
+
707
+ clear_btn.click(
708
+ lambda: (
709
+ [],
710
+ "*Thinking process will appear here when the AI is reasoning through your question...*",
711
+ "*Documentation sources will appear here when RAG search is performed...*"
712
+ ),
713
+ outputs=[chatbot, thinking_display, sources_display]
714
  )
715
 
716
+ # RAG status indicator
717
+ rag_status = "✅ Documentation search enabled" if retriever else "⚠️ Documentation search disabled (run generate_db.py)"
718
 
719
  gr.Markdown(
720
+ f"""
721
  ---
722
  ### About
723
  This assistant uses `jfang/gprmax-ft-Qwen3-4B-Instruct`, a model fine-tuned specifically for gprMax support.
724
 
725
+ **RAG Status**: {rag_status}
726
+
727
  **Note**: For best results, be specific about your gprMax version and simulation requirements.
728
  """
729
  )
rag-db/README.md ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # gprMax RAG Database System
2
+
3
+ ## Overview
4
+ This is a production-ready Retrieval-Augmented Generation (RAG) system for gprMax documentation. It provides efficient vector search capabilities for the gprMax documentation, enabling intelligent context retrieval for the chatbot.
5
+
6
+ ## Architecture
7
+
8
+ ### Components
9
+ 1. **Document Processor**: Extracts and chunks documentation from gprMax GitHub repository
10
+ 2. **Embedding Model**: ChromaDB default embeddings (all-MiniLM-L6-v2); a Qwen embedding model may be adopted later
11
+ 3. **Vector Database**: ChromaDB with persistent storage
12
+ 4. **Retriever**: Search and context retrieval utilities
13
+
14
+ ### Key Features
15
+ - Automatic documentation extraction from gprMax GitHub repository
16
+ - Intelligent chunking with configurable size and overlap
17
+ - Persistent vector database using ChromaDB
18
+ - Efficient similarity search with score thresholding
19
+ - Metadata tracking for reproducibility
20
+
21
+ ## Installation
22
+
23
+ The database is **automatically generated** on first startup of the application. No manual installation required!
24
+
25
+ ## Automatic Generation
26
+
27
+ When the app starts:
28
+ 1. Checks if database exists at `rag-db/chroma_db/`
29
+ 2. If not found, automatically runs `generate_db.py`
30
+ 3. Clones gprMax repository and processes documentation
31
+ 4. Creates ChromaDB with default embeddings (all-MiniLM-L6-v2)
32
+ 5. Ready to use - this only happens once!
33
+
34
+ ## Manual Generation (Optional)
35
+
36
+ If you need to manually regenerate the database:
37
+
38
+ ```bash
39
+ cd rag-db
40
+ python generate_db.py --recreate
41
+ ```
42
+
43
+ Custom settings:
44
+ ```bash
45
+ python generate_db.py \
46
+ --db-path ./custom_db \
47
+ --temp-dir ./temp \
48
+ --device cuda \
49
+ --recreate
50
+ ```
51
+
52
+ ### 2. Use Retriever in Application
53
+
54
+ ```python
55
+ import sys
+ sys.path.append("rag-db")  # directory name contains a hyphen, so it is not importable as a package
+ from retriever import create_retriever
56
+
57
+ # Initialize retriever
58
+ retriever = create_retriever(db_path="./rag-db/chroma_db")
59
+
60
+ # Search for relevant documents
61
+ results = retriever.search("How to create a source?", k=5)
62
+
63
+ # Get formatted context for LLM
64
+ context = retriever.get_context("antenna patterns", k=3)
65
+
66
+ # Get relevant source files
67
+ files = retriever.get_relevant_files("boundary conditions")
68
+
69
+ # Get database statistics
70
+ stats = retriever.get_stats()
71
+ ```
72
+
73
+ ### 3. Test Retriever
74
+
75
+ ```bash
76
+ # Test with default query
77
+ python retriever.py
78
+
79
+ # Test with custom query
80
+ python retriever.py "How to model soil layers?"
81
+ ```
82
+
83
+ ## Database Schema
84
+
85
+ ### Document Structure
86
+ ```json
87
+ {
88
+ "id": "unique_hash",
89
+ "text": "document_chunk_text",
90
+ "metadata": {
91
+ "source": "docs/relative/path.rst",
92
+ "file_type": ".rst",
93
+ "chunk_index": 0,
94
+ "char_start": 0,
95
+ "char_end": 1000
96
+ }
97
+ }
98
+ ```
99
+
100
+ ### Metadata File
101
+ Generated `metadata.json` contains:
102
+ ```json
103
+ {
104
+ "created_at": "2024-01-01T00:00:00",
105
+ "embedding_model": "Qwen/Qwen2.5-0.5B",
106
+ "collection_name": "gprmax_docs_v1",
107
+ "chunk_size": 1000,
108
+ "chunk_overlap": 200,
109
+ "total_documents": 1234
110
+ }
111
+ ```
112
+
113
+ ## Configuration
114
+
115
+ ### Chunking Parameters
116
+ - `CHUNK_SIZE`: 1000 characters (optimal for context windows)
117
+ - `CHUNK_OVERLAP`: 200 characters (ensures continuity)
118
+
119
+ ### Embedding Model
120
+ - Current: ChromaDB default (`all-MiniLM-L6-v2`, 384-dim embeddings)
121
+ - Future: `Qwen/Qwen3-Embedding-0.6B` (when available)
122
+
123
+ ### Database Settings
124
+ - Storage: ChromaDB persistent client
125
+ - Collection: `gprmax_docs_v1` (versioned for updates)
126
+ - Distance Metric: Cosine similarity
127
+
128
+ ## Maintenance
129
+
130
+ ### Regular Updates
131
+ Run monthly or when gprMax documentation updates:
132
+ ```bash
133
+ # This will pull latest docs and update database
134
+ python generate_db.py
135
+ ```
136
+
137
+ ### Database Backup
138
+ ```bash
139
+ # Backup database
140
+ cp -r chroma_db chroma_db_backup_$(date +%Y%m%d)
141
+ ```
142
+
143
+ ### Performance Tuning
144
+ - Adjust `CHUNK_SIZE` and `CHUNK_OVERLAP` in `generate_db.py`
145
+ - Modify batch sizes for large datasets
146
+ - Use GPU acceleration with `--device cuda`
147
+
148
+ ## Integration with Main App
149
+
150
+ The RAG system integrates with the main Gradio app:
151
+
152
+ 1. Import retriever in `app.py`
153
+ 2. Use retriever to augment prompts with context
154
+ 3. Display source references in UI
155
+
156
+ Example integration:
157
+ ```python
158
+ # In app.py
159
+ import sys
+ sys.path.append("rag-db")  # hyphenated directory is not a valid package name
+ from retriever import create_retriever
160
+
161
+ retriever = create_retriever()
162
+
163
+ def augment_with_context(user_query):
164
+ context = retriever.get_context(user_query, k=3)
165
+ augmented_prompt = f"""
166
+ Context from documentation:
167
+ {context}
168
+
169
+ User question: {user_query}
170
+ """
171
+ return augmented_prompt
172
+ ```
173
+
174
+ ## Troubleshooting
175
+
176
+ ### Common Issues
177
+
178
+ 1. **Database not found**
179
+ - Run `python generate_db.py` first
180
+ - Check `--db-path` parameter
181
+
182
+ 2. **Out of memory**
183
+ - Use smaller batch sizes
184
+ - Use CPU instead of GPU
185
+ - Reduce chunk size
186
+
187
+ 3. **Slow generation**
188
+ - Use GPU with `--device cuda`
189
+ - Reduce repository depth with shallow clone
190
+ - Use pre-generated database
191
+
192
+ ### Logs
193
+ Check generation logs for detailed information:
194
+ ```bash
195
+ python generate_db.py 2>&1 | tee generation.log
196
+ ```
197
+
198
+ ## Future Enhancements
199
+
200
+ 1. **Model Upgrade**: Migrate to Qwen3-Embedding-0.6B when available
201
+ 2. **Incremental Updates**: Add documents without full regeneration
202
+ 3. **Multi-modal Support**: Include images and diagrams from docs
203
+ 4. **Query Expansion**: Automatic query reformulation for better retrieval
204
+ 5. **Caching Layer**: Redis cache for frequent queries
205
+ 6. **Fine-tuned Embeddings**: Domain-specific embedding model for gprMax
206
+
207
+ ## License
208
+ Same as parent project
rag-db/__init__.py ADDED
File without changes
rag-db/generate_db.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RAG Database Generation Script for gprMax Documentation
4
+ Generates a ChromaDB vector database from gprMax documentation
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import shutil
10
+ import argparse
11
+ import logging
12
+ from pathlib import Path
13
+ from datetime import datetime
14
+ from typing import List, Dict, Any
15
+ import json
16
+ import hashlib
17
+
18
+ import chromadb
19
+ import git
20
+ from tqdm import tqdm
21
+
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class GprMaxDocumentProcessor:
27
+ """Process gprMax documentation files for vectorization"""
28
+
29
+ SUPPORTED_EXTENSIONS = {'.rst', '.md', '.txt'}
30
+ CHUNK_SIZE = 1000 # Characters per chunk
31
+ CHUNK_OVERLAP = 200 # Overlap between chunks
32
+
33
+ def __init__(self, repo_path: Path):
34
+ self.repo_path = repo_path
35
+ self.doc_path = repo_path / "docs"
36
+
37
+ def extract_documents(self) -> List[Dict[str, Any]]:
38
+ """Extract and chunk all documentation files"""
39
+ documents = []
40
+
41
+ if not self.doc_path.exists():
42
+ logger.warning(f"Documentation path {self.doc_path} does not exist")
43
+ return documents
44
+
45
+ for file_path in self._find_doc_files():
46
+ try:
47
+ chunks = self._process_file(file_path)
48
+ documents.extend(chunks)
49
+ except Exception as e:
50
+ logger.error(f"Error processing {file_path}: {e}")
51
+
52
+ logger.info(f"Extracted {len(documents)} document chunks")
53
+ return documents
54
+
55
+ def _find_doc_files(self) -> List[Path]:
56
+ """Find all documentation files"""
57
+ doc_files = []
58
+ for ext in self.SUPPORTED_EXTENSIONS:
59
+ doc_files.extend(self.doc_path.rglob(f"*{ext}"))
60
+ return doc_files
61
+
62
+ def _process_file(self, file_path: Path) -> List[Dict[str, Any]]:
63
+ """Process a single file into chunks"""
64
+ with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
65
+ content = f.read()
66
+
67
+ # Calculate relative path for metadata
68
+ rel_path = file_path.relative_to(self.repo_path)
69
+
70
+ # Create chunks with overlap
71
+ chunks = []
72
+ for i in range(0, len(content), self.CHUNK_SIZE - self.CHUNK_OVERLAP):
73
+ chunk_text = content[i:i + self.CHUNK_SIZE]
74
+
75
+ # Skip empty or very small chunks
76
+ if len(chunk_text.strip()) < 50:
77
+ continue
78
+
79
+ # Generate unique ID for chunk
80
+ chunk_id = hashlib.md5(f"{rel_path}_{i}_{chunk_text[:50]}".encode()).hexdigest()
81
+
82
+ chunks.append({
83
+ "id": chunk_id,
84
+ "text": chunk_text,
85
+ "metadata": {
86
+ "source": str(rel_path),
87
+ "file_type": file_path.suffix,
88
+ "chunk_index": len(chunks),
89
+ "char_start": i,
90
+ "char_end": min(i + self.CHUNK_SIZE, len(content))
91
+ }
92
+ })
93
+
94
+ return chunks
95
+
96
+
97
+ # Removed custom embedding model - using ChromaDB's default
98
+
99
+
100
class ChromaRAGDatabase:
    """ChromaDB-based RAG database.

    Wraps a persistent ChromaDB client and one versioned collection,
    letting ChromaDB generate embeddings with its default model.
    """

    def __init__(self, db_path: Path):
        """Open (or create) a persistent ChromaDB store at ``db_path``."""
        self.db_path = db_path

        # Initialize ChromaDB with persistent storage
        self.client = chromadb.PersistentClient(path=str(db_path))

        # Collection name with version for easy updates
        self.collection_name = "gprmax_docs_v1"

    def create_collection(self, recreate: bool = False):
        """Create or reuse the document collection.

        Args:
            recreate: When True, drop any existing collection first so
                the database is rebuilt from scratch.
        """
        if recreate:
            try:
                self.client.delete_collection(self.collection_name)
                logger.info(f"Deleted existing collection: {self.collection_name}")
            except Exception:
                # The collection may simply not exist yet; nothing to delete.
                # (Narrowed from a bare `except:` so Ctrl-C / SystemExit
                # are not swallowed here.)
                pass

        # get_or_create avoids crashing when the collection already exists
        # and --recreate was not requested. ChromaDB supplies its default
        # embedding function.
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"created_at": datetime.now().isoformat()}
        )
        logger.info(f"Created collection: {self.collection_name}")

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Add document chunks to the collection in batches.

        Args:
            documents: Chunk dicts with ``id``, ``text`` and ``metadata``
                keys (as produced by the document processor).
        """
        if not documents:
            logger.warning("No documents to add")
            return

        # Prepare data for ChromaDB
        ids = [doc["id"] for doc in documents]
        texts = [doc["text"] for doc in documents]
        metadatas = [doc["metadata"] for doc in documents]

        # Add to collection in batches (ChromaDB will generate embeddings automatically)
        batch_size = 100
        logger.info(f"Adding {len(documents)} documents to database...")
        for i in tqdm(range(0, len(ids), batch_size), desc="Adding to database"):
            end_idx = min(i + batch_size, len(ids))
            self.collection.add(
                ids=ids[i:end_idx],
                documents=texts[i:end_idx],
                metadatas=metadatas[i:end_idx]
                # No embeddings parameter - ChromaDB will generate them
            )

        logger.info(f"Added {len(documents)} documents to database")

        # Verify documents were added
        actual_count = self.collection.count()
        logger.info(f"Verified collection now contains {actual_count} documents")

    def save_metadata(self):
        """Write a metadata.json describing this build next to the DB."""
        # Get fresh count
        doc_count = self.collection.count()

        metadata = {
            "created_at": datetime.now().isoformat(),
            "embedding_model": "ChromaDB Default (all-MiniLM-L6-v2)",
            "collection_name": self.collection_name,
            "chunk_size": GprMaxDocumentProcessor.CHUNK_SIZE,
            "chunk_overlap": GprMaxDocumentProcessor.CHUNK_OVERLAP,
            "total_documents": doc_count
        }

        metadata_path = self.db_path / "metadata.json"
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        logger.info(f"Saved metadata to {metadata_path}")
176
+
177
+
178
def clone_gprmax_repo(target_dir: Path) -> Path:
    """Clone or update gprMax repository.

    Args:
        target_dir: Directory under which the ``gprMax`` checkout lives.

    Returns:
        Path to the local repository.
    """
    repo_path = target_dir / "gprMax"

    if not repo_path.exists():
        logger.info(f"Cloning gprMax repository to {repo_path}")
        git.Repo.clone_from(
            "https://github.com/gprMax/gprMax.git",
            repo_path,
            depth=1  # Shallow clone for faster download
        )
        return repo_path

    logger.info(f"Updating existing repository at {repo_path}")
    git.Repo(repo_path).remotes.origin.pull()
    return repo_path
195
+
196
+
197
def main():
    """Build the gprMax RAG database end to end.

    Pipeline: clone/update the gprMax repo, chunk its documentation,
    store the chunks in a persistent ChromaDB collection, write a
    metadata.json, then clean up the temporary checkout.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(description="Generate RAG database from gprMax documentation")
    parser.add_argument(
        "--db-path",
        type=Path,
        default=Path(__file__).parent / "chroma_db",
        help="Path to store the ChromaDB database"
    )
    parser.add_argument(
        "--temp-dir",
        type=Path,
        default=Path(__file__).parent / "temp",
        help="Temporary directory for cloning repository"
    )
    parser.add_argument(
        "--recreate",
        action="store_true",
        help="Recreate database from scratch (delete existing)"
    )

    args = parser.parse_args()

    try:
        # Step 1: Clone/update gprMax repository
        logger.info("Step 1: Fetching gprMax repository...")
        repo_path = clone_gprmax_repo(args.temp_dir)

        # Step 2: Process documentation
        logger.info("Step 2: Processing documentation files...")
        processor = GprMaxDocumentProcessor(repo_path)
        documents = processor.extract_documents()

        if not documents:
            logger.error("No documents found to process")
            return 1

        # Step 3: Create database
        logger.info("Step 3: Creating vector database...")
        db = ChromaRAGDatabase(args.db_path)
        db.create_collection(recreate=args.recreate)

        # Step 4: Add documents
        logger.info("Step 4: Adding documents to database...")
        db.add_documents(documents)

        # Step 5: Save metadata
        db.save_metadata()

        logger.info(f"✅ Database successfully created at {args.db_path}")
        logger.info(f"Total documents: {len(documents)}")

        # Cleanup temp files if needed; guard keeps us from deleting the
        # database's own parent directory.
        if args.temp_dir.exists() and args.temp_dir != args.db_path.parent:
            logger.info("Cleaning up temporary files...")
            shutil.rmtree(args.temp_dir, ignore_errors=True)

        return 0

    except Exception as e:
        # logger.exception records the full traceback (logger.error did
        # not), so build failures are actually diagnosable from the log.
        logger.exception(f"Failed to generate database: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())
rag-db/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # RAG Database Requirements
2
+ chromadb>=0.4.22
3
+ GitPython>=3.1.40
4
+ tqdm>=4.66.1
5
+ torch>=2.0.0
6
+ transformers>=4.44.0
7
+ sentencepiece
rag-db/retriever.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG Retrieval Utilities for gprMax Documentation
3
+ Provides search and retrieval functions for the vector database
4
+ """
5
+
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import List, Dict, Any, Optional, Tuple
9
+ import json
10
+
11
+ import chromadb
12
+ from dataclasses import dataclass
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
@dataclass
class SearchResult:
    """One retrieved chunk: its text, similarity score, and metadata."""
    text: str
    score: float
    metadata: Dict[str, Any]

    def __str__(self) -> str:
        # Compact one-line summary: score, source file, 100-char preview.
        source = self.metadata.get('source', 'Unknown')
        preview = self.text[:100]
        return f"[Score: {self.score:.3f}] {source}: {preview}..."
26
+
27
+
28
+ # Removed QwenEmbeddingModel class - using ChromaDB's default embedding
29
+
30
+
31
class GprMaxRAGRetriever:
    """Retriever for gprMax documentation RAG database.

    Opens an existing persistent ChromaDB store (built by generate_db.py)
    and exposes search/context helpers over its single collection.
    """

    def __init__(self, db_path: Optional[Path] = None):
        """Open the database at ``db_path`` (default: ./chroma_db).

        Raises:
            ValueError: If the path does not exist or the collection
                cannot be loaded.
        """
        if db_path is None:
            db_path = Path(__file__).parent / "chroma_db"

        if not db_path.exists():
            raise ValueError(f"Database path {db_path} does not exist. Run generate_db.py first.")

        self.db_path = db_path

        # Load metadata written by generate_db.py; tolerate its absence.
        metadata_path = db_path / "metadata.json"
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                self.metadata = json.load(f)
        else:
            self.metadata = {}

        # Initialize ChromaDB client
        self.client = chromadb.PersistentClient(path=str(db_path))

        # Get collection; name comes from metadata, with a fallback that
        # matches the generator's hard-coded default.
        self.collection_name = self.metadata.get("collection_name", "gprmax_docs_v1")
        try:
            print(f"[RAG] Loading collection: {self.collection_name}")
            self.collection = self.client.get_collection(self.collection_name)
            doc_count = self.collection.count()
            print(f"[RAG] Loaded collection: {self.collection_name} with {doc_count} documents")
            logger.info(f"Loaded collection: {self.collection_name} with {doc_count} documents")
        except Exception as e:
            print(f"[RAG] ERROR loading collection: {e}")
            raise ValueError(f"Failed to load collection {self.collection_name}: {e}")

    def search(
        self,
        query: str,
        k: int = 10,
        threshold: float = 0.0,
        filter_metadata: Optional[Dict[str, Any]] = None
    ) -> List[SearchResult]:
        """
        Search for relevant documents

        Args:
            query: Search query text
            k: Number of results to return
            threshold: Minimum similarity score threshold
            filter_metadata: Optional metadata filters

        Returns:
            List of SearchResult objects
        """
        # Search in ChromaDB (it will generate embeddings automatically)
        try:
            results = self.collection.query(
                query_texts=[query],  # Use query_texts instead of query_embeddings
                n_results=k,
                where=filter_metadata if filter_metadata else None,
                include=["documents", "metadatas", "distances"]
            )
            logger.info(f"ChromaDB query returned: {len(results.get('documents', [[]])[0]) if results.get('documents') else 0} results")
        except Exception as e:
            logger.error(f"ChromaDB query failed: {e}")
            raise

        # Convert to SearchResult objects
        search_results = []
        if results["documents"] and results["documents"][0]:
            for doc, meta, dist in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0]
            ):
                # Convert distance to similarity score (1 - normalized_distance)
                # NOTE(review): this normalization assumes cosine distance in
                # [0, 2]; ChromaDB's default space is l2, where distances are
                # unbounded and scores can go negative — TODO confirm the
                # collection's distance metric.
                score = 1.0 - (dist / 2.0)  # Assuming cosine distance in [-1, 1]

                if score >= threshold:
                    search_results.append(SearchResult(
                        text=doc,
                        score=score,
                        metadata=meta
                    ))

        return search_results

    def get_context(
        self,
        query: str,
        k: int = 3,
        max_context_length: int = 2000,
        format_as_markdown: bool = True
    ) -> str:
        """
        Get formatted context for a query

        Args:
            query: Search query
            k: Number of documents to retrieve
            max_context_length: Maximum total context length
            format_as_markdown: Format output as markdown

        Returns:
            Formatted context string
        """
        results = self.search(query, k=k)

        if not results:
            return "No relevant documentation found."

        context_parts = []
        total_length = 0

        for i, result in enumerate(results, 1):
            if total_length >= max_context_length:
                break

            # Truncate so accumulated document text stays within the budget.
            # (Markdown headers/fences are not counted toward the limit.)
            text = result.text
            if total_length + len(text) > max_context_length:
                text = text[:max_context_length - total_length]

            if format_as_markdown:
                source = result.metadata.get("source", "Unknown")
                context_parts.append(
                    f"### Document {i} (Source: {source}, Score: {result.score:.3f})\n"
                    f"```\n{text}\n```\n"
                )
            else:
                context_parts.append(text)

            total_length += len(text)

        return "\n".join(context_parts)

    def get_relevant_files(self, query: str, k: int = 5) -> List[str]:
        """Get list of relevant source files for a query.

        Returns:
            Sorted, de-duplicated source paths from the top-k results.
        """
        results = self.search(query, k=k)

        # Extract unique source files
        sources = set()
        for result in results:
            source = result.metadata.get("source")
            if source:
                sources.add(source)

        return sorted(list(sources))

    def search_by_file(self, file_pattern: str, k: int = 10) -> List[SearchResult]:
        """Return up to ``k`` chunks whose source path contains the pattern
        (case-insensitive substring match).
        """
        # This would need ChromaDB's where clause with pattern matching
        # For now, we do a broad search and filter.
        # NOTE(review): only the first 1000 stored chunks are scanned, so
        # matches beyond that limit are silently missed.
        results = self.collection.get(
            limit=1000,  # Get many results
            include=["documents", "metadatas"]
        )

        filtered_results = []
        if results["documents"]:
            for doc, meta in zip(results["documents"], results["metadatas"]):
                source = meta.get("source", "")
                if file_pattern.lower() in source.lower():
                    filtered_results.append(SearchResult(
                        text=doc,
                        score=1.0,  # No score for direct retrieval
                        metadata=meta
                    ))

                if len(filtered_results) >= k:
                    break

        return filtered_results

    def get_stats(self) -> Dict[str, Any]:
        """Get database statistics (live document count plus build metadata)."""
        stats = {
            "total_documents": self.collection.count(),
            "database_path": str(self.db_path),
            "collection_name": self.collection_name,
            "embedding_model": self.metadata.get("embedding_model", "Unknown"),
            "created_at": self.metadata.get("created_at", "Unknown"),
            "chunk_size": self.metadata.get("chunk_size", "Unknown"),
            "chunk_overlap": self.metadata.get("chunk_overlap", "Unknown")
        }
        return stats
217
+
218
+
219
def create_retriever(db_path: Optional[Path] = None) -> GprMaxRAGRetriever:
    """Factory function to create a retriever instance.

    Args:
        db_path: Optional database location; defaults to the package's
            ``chroma_db`` directory when omitted.
    """
    retriever = GprMaxRAGRetriever(db_path=db_path)
    return retriever
222
+
223
+
224
if __name__ == "__main__":
    # Smoke-test entry point: run a query against the local database and
    # print stats, raw results, and the formatted context.
    # Example usage
    import sys

    # Query comes from the command line, or falls back to a demo question.
    if len(sys.argv) > 1:
        query = " ".join(sys.argv[1:])
    else:
        query = "How to create a source in gprMax?"

    print(f"Testing retriever with query: '{query}'")
    print("-" * 80)

    try:
        retriever = create_retriever()

        # Get stats
        stats = retriever.get_stats()
        print(f"Database stats: {stats}")
        print("-" * 80)

        # Search
        results = retriever.search(query, k=3)
        print(f"Found {len(results)} results:")
        for i, result in enumerate(results, 1):
            print(f"\n{i}. {result}")

        # Get formatted context
        print("\n" + "=" * 80)
        print("Formatted context:")
        print(retriever.get_context(query, k=3))

    except Exception as e:
        # Any failure (missing DB, broken collection) exits non-zero.
        print(f"Error: {e}")
        sys.exit(1)
requirements.txt CHANGED
@@ -4,4 +4,7 @@ spaces
4
  accelerate
5
  sentencepiece
6
  einops
7
- numpy < 2.0.0
 
 
 
 
4
  accelerate
5
  sentencepiece
6
  einops
7
+ numpy < 2.0.0
8
+ chromadb
9
+ GitPython
10
+ tqdm