MedSpace / scripts /generate_system_design.py
kbsss's picture
Upload folder using huggingface_hub
f373e2b verified
Raw
History Blame Contribute Delete
3.26 kB
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
def draw_box(ax, x, y, w, h, text, color='#E0E0E0', edge='black', fontweight='normal'):
    """Draw a filled rectangle with a centred label on ``ax``.

    Args:
        ax: Axes-like object; must provide ``add_patch`` and ``text``.
        x, y: Anchor corner of the rectangle, in data coordinates.
        w, h: Width and height of the rectangle.
        text: Label placed at the centre of the rectangle.
        color: Fill colour of the rectangle.
        edge: Outline colour of the rectangle.
        fontweight: Font weight used for the label.

    Returns:
        The tuple ``(x, y, w, h)``, exactly as passed in.
    """
    box_style = dict(linewidth=1.5, edgecolor=edge, facecolor=color, zorder=3)
    ax.add_patch(patches.Rectangle((x, y), w, h, **box_style))

    # The label is drawn one z-level above the box so the fill never hides it.
    centre_x = x + w/2
    centre_y = y + h/2
    ax.text(
        centre_x, centre_y, text,
        ha='center', va='center', fontsize=10, zorder=4, wrap=True,
        fontweight=fontweight,
    )
    return (x, y, w, h)
def draw_arrow(ax, x1, y1, x2, y2, text=None, color='black'):
    """Draw an arrow from ``(x1, y1)`` to ``(x2, y2)`` with an optional label.

    Args:
        ax: Axes-like object; must provide ``annotate`` and ``text``.
        x1, y1: Tail of the arrow, in data coordinates.
        x2, y2: Head of the arrow, in data coordinates.
        text: Optional label drawn at the midpoint of the arrow. A falsy
            value (``None`` or an empty string) draws no label.
        color: Colour of the arrow line and head.
    """
    arrow_style = {"arrowstyle": "->", "lw": 1.5, "color": color}
    # The annotation text is empty on purpose: only the arrow itself is wanted.
    ax.annotate("", xy=(x2, y2), xytext=(x1, y1), arrowprops=arrow_style, zorder=5)

    if not text:
        return

    # The label gets a semi-opaque white backing so it stays readable where
    # it overlaps the arrow or a box.
    label_backing = {"facecolor": 'white', "edgecolor": 'none', "alpha": 0.8}
    mid_x = (x1+x2)/2
    mid_y = (y1+y2)/2
    ax.text(
        mid_x, mid_y, text,
        ha='center', va='bottom', fontsize=8, color='#333333',
        bbox=label_backing, zorder=6,
    )
def generate_diagram(output_path="docs/architecture/system_design.png"):
    """Render the Medical RAG system-design diagram and save it as a PNG.

    The diagram is drawn on a 14 x 10 canvas with the axes hidden. It shows
    the user, the Streamlit frontend, the FastAPI backend and, inside a
    dashed boundary, the pipeline stages (hybrid retriever, its two data
    sources, reranker/grounding gate, LLM, XAI module), connected by
    numbered arrows.

    Args:
        output_path: Destination file for the image. Defaults to the
            location the script has always written to. A relative path is
            resolved against the current working directory.

    Raises:
        OSError: If the output directory cannot be created or the file
            cannot be written.
    """
    fig, ax = plt.subplots(figsize=(14, 10))
    try:
        ax.set_xlim(0, 14)
        ax.set_ylim(0, 10)
        ax.axis('off')

        # Title
        ax.text(7, 9.5, "Advanced Medical RAG System Design", ha='center', fontsize=18, weight='bold')

        # External User
        draw_box(ax, 0.5, 7.5, 2, 1, "Medical Professional\n(User)", color='#FFE4B5', fontweight='bold')

        # Frontend
        draw_box(ax, 3.5, 7.5, 2.5, 1, "Frontend\n(Streamlit)", color='#ADD8E6')
        draw_arrow(ax, 2.5, 8, 3.5, 8, "Query / UI")

        # API
        draw_box(ax, 7, 7.5, 2.5, 1, "Backend API\n(FastAPI)", color='#98FB98')
        draw_arrow(ax, 6, 8, 7, 8, "REST/JSON")

        # Pipeline Boundary (Large Box): zorder=1 keeps it behind every
        # stage box, arrow and label drawn inside it.
        pipeline_rect = patches.Rectangle(
            (3, 1), 10, 5.5,
            linewidth=2, edgecolor='#555555', facecolor='#FAFAFA',
            linestyle='--', zorder=1,
        )
        ax.add_patch(pipeline_rect)
        ax.text(8, 6.2, "Healthcare RAG Pipeline (Core Orchestrator)", ha='center', fontsize=12, weight='bold', color='#444444')

        # Retrieval Stage: arrow runs from the bottom of the API box to the
        # top of the retriever box.
        draw_box(ax, 3.5, 4.5, 2.5, 1, "Hybrid Retriever\n(Dense + Sparse)", color='#FFD700')
        draw_arrow(ax, 8.25, 7.5, 4.75, 5.5, "1. Retrieve")

        # Data Sources
        draw_box(ax, 3.5, 1.5, 1.1, 0.8, "ChromaDB\n(Dense)", color='#F0E68C')
        draw_box(ax, 4.9, 1.5, 1.1, 0.8, "BM25\n(Sparse)", color='#F0E68C')
        draw_arrow(ax, 4.05, 4.5, 4.05, 2.3)
        draw_arrow(ax, 5.45, 4.5, 5.45, 2.3)

        # Reranker & Grounding
        draw_box(ax, 7, 4.5, 2.5, 1, "Reranker &\nGrounding Gate", color='#FF6347')
        draw_arrow(ax, 6, 5, 7, 5, "2. Refine")

        # Generation
        draw_box(ax, 10, 4.5, 2.5, 1, "Medical LLM\n(BioMistral/TinyLlama)", color='#DDA0DD')
        draw_arrow(ax, 9.5, 5, 10, 5, "3. Generate")

        # XAI Module
        draw_box(ax, 10, 2.5, 2.5, 1, "XAI Module\n(Explainability)", color='#87CEFA')
        draw_arrow(ax, 11.25, 4.5, 11.25, 3.5, "4. Explain")

        # Final Response: from the left edge of the XAI box back up to the
        # bottom of the API box.
        draw_arrow(ax, 10, 3, 8.25, 7.5, "5. Response + XAI", color='blue')

        # Legend / Info
        info_text = "Key Features:\n• Hybrid RRF Retrieval\n• Grounding Check (Anti-Hallucination)\n• Source Attribution\n• Confidence Scoring"
        ax.text(0.5, 0.5, info_text, fontsize=10, bbox=dict(facecolor='white', alpha=0.5))

        # savefig does not create missing parent directories, so make sure
        # the target directory exists first. A bare file name has an empty
        # dirname, which os.makedirs would reject, hence the guard.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving failed; otherwise it
        # stays registered with pyplot for the life of the process.
        plt.close(fig)

    print(f"Diagram saved to {output_path}")
# Render the diagram only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    generate_diagram()