mshabir committed on
Commit
e13bd2f
·
verified ·
1 Parent(s): 12f26a4

Update medical_rag_system.py

Browse files
Files changed (1) hide show
  1. medical_rag_system.py +68 -10
medical_rag_system.py CHANGED
@@ -1,13 +1,15 @@
1
  import faiss
2
  import pickle
3
- from sentence_transformers import SentenceTransformer
4
  import numpy as np
5
  import os
 
 
 
 
6
 
7
  class MedicalRAGSystem:
8
  def __init__(self, vector_store_path="medical_rag/vector_store"):
9
  self.vector_store_path = vector_store_path
10
- self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
11
  self.load_vector_store()
12
 
13
  def load_vector_store(self):
@@ -18,7 +20,7 @@ class MedicalRAGSystem:
18
  self.vector_store_path,
19
  "medical_rag/vector_store",
20
  "./medical_rag/vector_store",
21
- "/app/medical_rag/vector_store", # For Hugging Face Spaces
22
  ]
23
 
24
  loaded = False
@@ -31,6 +33,15 @@ class MedicalRAGSystem:
31
  self.index = faiss.read_index(index_path)
32
  with open(metadata_path, "rb") as f:
33
  data = pickle.load(f)
 
 
 
 
 
 
 
 
 
34
  self.chunks = data['chunks']
35
  self.metadata = data['metadata']
36
  print(f"✅ Vector store loaded from: {path}")
@@ -41,23 +52,70 @@ class MedicalRAGSystem:
41
  continue
42
 
43
  if not loaded:
44
- raise Exception("Could not load vector store from any path. Check if vector store files exist.")
45
 
46
  except Exception as e:
47
  raise Exception(f"Error loading vector store: {str(e)}")
48
 
49
- def retrieve_similar_chunks(self, query, k=5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  """Retrieve similar medical chunks for a query"""
51
  try:
52
- # Encode query
53
- query_embedding = self.embedding_model.encode([query])
 
 
 
54
  faiss.normalize_L2(query_embedding)
55
 
56
- # Search (get extra for potential filtering)
57
- scores, indices = self.index.search(query_embedding, k*3)
58
 
59
  results = []
60
- seen_chunks = set() # Avoid duplicates
61
 
62
  for score, idx in zip(scores[0], indices[0]):
63
  if idx < len(self.chunks) and idx not in seen_chunks:
 
1
import base64
import json
import os
import pickle
import zlib
from typing import List, Dict, Any

import faiss
import numpy as np
import requests
9
 
10
  class MedicalRAGSystem:
11
def __init__(self, vector_store_path: str = "medical_rag/vector_store"):
    """Initialize the RAG system and eagerly load the FAISS vector store.

    Args:
        vector_store_path: Preferred directory holding the FAISS index and
            pickled chunk metadata. ``load_vector_store`` also probes a few
            alternate locations (e.g. ``/app/medical_rag/vector_store``)
            if this one is missing.
    """
    self.vector_store_path = vector_store_path
    # Fail fast: loading here (rather than lazily) surfaces a missing or
    # corrupt vector store at construction time.
    self.load_vector_store()
14
 
15
  def load_vector_store(self):
 
20
  self.vector_store_path,
21
  "medical_rag/vector_store",
22
  "./medical_rag/vector_store",
23
+ "/app/medical_rag/vector_store",
24
  ]
25
 
26
  loaded = False
 
33
  self.index = faiss.read_index(index_path)
34
  with open(metadata_path, "rb") as f:
35
  data = pickle.load(f)
36
+
37
+ # Load pre-computed embeddings
38
+ embeddings_path = f"{path}/embeddings.npy"
39
+ if os.path.exists(embeddings_path):
40
+ self.embeddings = np.load(embeddings_path)
41
+ else:
42
+ # If no embeddings file, we'll use the index directly
43
+ self.embeddings = None
44
+
45
  self.chunks = data['chunks']
46
  self.metadata = data['metadata']
47
  print(f"✅ Vector store loaded from: {path}")
 
52
  continue
53
 
54
  if not loaded:
55
+ raise Exception("Could not load vector store from any path.")
56
 
57
  except Exception as e:
58
  raise Exception(f"Error loading vector store: {str(e)}")
59
 
60
def get_embedding_huggingface(self, text: str) -> np.ndarray:
    """Embed *text* with the Hugging Face Inference API (MiniLM-L6-v2).

    Degrades gracefully: any failure (non-200 status, network error,
    missing/invalid ``HF_TOKEN``, bad JSON) falls back to the local
    ``simple_embedding`` instead of raising.

    Args:
        text: The query or document text to embed.

    Returns:
        A numpy array embedding. On the API path this is whatever shape
        the feature-extraction endpoint returns (the caller reshapes to
        ``(1, -1)``); on the fallback path it is a (384,) vector.
    """
    API_URL = (
        "https://api-inference.huggingface.co/pipeline/feature-extraction/"
        "sentence-transformers/all-MiniLM-L6-v2"
    )
    try:
        headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN', '')}"}
        response = requests.post(
            API_URL,
            headers=headers,
            json={"inputs": text, "options": {"wait_for_model": True}},
            # BUG FIX: the original call had no timeout, so a stalled
            # connection would hang the whole request path forever.
            timeout=30,
        )
        if response.status_code == 200:
            return np.array(response.json())
        # Non-200: log and degrade to the deterministic local fallback.
        print(f"Hugging Face API failed: {response.status_code}")
        return self.simple_embedding(text)
    except Exception as e:
        # Network/timeout/JSON errors all degrade to the local fallback.
        print(f"Embedding error: {e}")
        return self.simple_embedding(text)
82
+
83
def simple_embedding(self, text: str) -> np.ndarray:
    """Deterministic hashing-trick fallback embedding.

    Builds a 384-dimensional bag-of-words vector (matching MiniLM's
    output dimension) by hashing each whitespace token into a bucket,
    then L2-normalizes the result.

    BUG FIX: the original used the builtin ``hash()``, which is salted
    per process (PYTHONHASHSEED), so embeddings were NOT reproducible
    across runs — a fresh query vector could never line up with vectors
    computed in an earlier process. ``zlib.crc32`` is stable across
    processes and platforms.

    Args:
        text: Text to embed; tokenized by lowercasing and splitting on
            whitespace.

    Returns:
        A float vector of shape (384,): unit L2-norm for non-empty
        input, all zeros for empty/whitespace-only input.
    """
    embedding = np.zeros(384)
    for word in text.lower().split():
        # Stable bucket assignment, independent of interpreter hash seed.
        bucket = zlib.crc32(word.encode("utf-8")) % 384
        embedding[bucket] += 1
    norm = np.linalg.norm(embedding)
    if norm > 0:
        embedding = embedding / norm
    return embedding
103
+
104
+ def retrieve_similar_chunks(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
105
  """Retrieve similar medical chunks for a query"""
106
  try:
107
+ # Get query embedding
108
+ query_embedding = self.get_embedding_huggingface(query).astype('float32')
109
+ query_embedding = query_embedding.reshape(1, -1)
110
+
111
+ # Normalize for cosine similarity
112
  faiss.normalize_L2(query_embedding)
113
 
114
+ # Search
115
+ scores, indices = self.index.search(query_embedding, min(k * 3, len(self.chunks)))
116
 
117
  results = []
118
+ seen_chunks = set()
119
 
120
  for score, idx in zip(scores[0], indices[0]):
121
  if idx < len(self.chunks) and idx not in seen_chunks: