Spaces:
Sleeping
Sleeping
Update medical_rag_system.py
Browse files- medical_rag_system.py +68 -10
medical_rag_system.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
| 1 |
import faiss
|
| 2 |
import pickle
|
| 3 |
-
from sentence_transformers import SentenceTransformer
|
| 4 |
import numpy as np
|
| 5 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
class MedicalRAGSystem:
|
| 8 |
def __init__(self, vector_store_path="medical_rag/vector_store"):
|
| 9 |
self.vector_store_path = vector_store_path
|
| 10 |
-
self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
| 11 |
self.load_vector_store()
|
| 12 |
|
| 13 |
def load_vector_store(self):
|
|
@@ -18,7 +20,7 @@ class MedicalRAGSystem:
|
|
| 18 |
self.vector_store_path,
|
| 19 |
"medical_rag/vector_store",
|
| 20 |
"./medical_rag/vector_store",
|
| 21 |
-
"/app/medical_rag/vector_store",
|
| 22 |
]
|
| 23 |
|
| 24 |
loaded = False
|
|
@@ -31,6 +33,15 @@ class MedicalRAGSystem:
|
|
| 31 |
self.index = faiss.read_index(index_path)
|
| 32 |
with open(metadata_path, "rb") as f:
|
| 33 |
data = pickle.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
self.chunks = data['chunks']
|
| 35 |
self.metadata = data['metadata']
|
| 36 |
print(f"✅ Vector store loaded from: {path}")
|
|
@@ -41,23 +52,70 @@ class MedicalRAGSystem:
|
|
| 41 |
continue
|
| 42 |
|
| 43 |
if not loaded:
|
| 44 |
-
raise Exception("Could not load vector store from any path.
|
| 45 |
|
| 46 |
except Exception as e:
|
| 47 |
raise Exception(f"Error loading vector store: {str(e)}")
|
| 48 |
|
| 49 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
"""Retrieve similar medical chunks for a query"""
|
| 51 |
try:
|
| 52 |
-
#
|
| 53 |
-
query_embedding = self.
|
|
|
|
|
|
|
|
|
|
| 54 |
faiss.normalize_L2(query_embedding)
|
| 55 |
|
| 56 |
-
# Search
|
| 57 |
-
scores, indices = self.index.search(query_embedding, k*3)
|
| 58 |
|
| 59 |
results = []
|
| 60 |
-
seen_chunks = set()
|
| 61 |
|
| 62 |
for score, idx in zip(scores[0], indices[0]):
|
| 63 |
if idx < len(self.chunks) and idx not in seen_chunks:
|
|
|
|
| 1 |
import faiss
|
| 2 |
import pickle
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import os
|
| 5 |
+
import requests
|
| 6 |
+
import json
|
| 7 |
+
from typing import List, Dict, Any
|
| 8 |
+
import base64
|
| 9 |
|
| 10 |
class MedicalRAGSystem:
|
| 11 |
def __init__(self, vector_store_path="medical_rag/vector_store"):
    """Create the RAG system and eagerly load the FAISS vector store.

    Args:
        vector_store_path: Preferred directory holding the FAISS index and
            metadata pickle; ``load_vector_store`` also probes a few
            alternative locations if this one is missing.
    """
    self.vector_store_path = vector_store_path
    self.load_vector_store()
|
| 14 |
|
| 15 |
def load_vector_store(self):
|
|
|
|
| 20 |
self.vector_store_path,
|
| 21 |
"medical_rag/vector_store",
|
| 22 |
"./medical_rag/vector_store",
|
| 23 |
+
"/app/medical_rag/vector_store",
|
| 24 |
]
|
| 25 |
|
| 26 |
loaded = False
|
|
|
|
| 33 |
self.index = faiss.read_index(index_path)
|
| 34 |
with open(metadata_path, "rb") as f:
|
| 35 |
data = pickle.load(f)
|
| 36 |
+
|
| 37 |
+
# Load pre-computed embeddings
|
| 38 |
+
embeddings_path = f"{path}/embeddings.npy"
|
| 39 |
+
if os.path.exists(embeddings_path):
|
| 40 |
+
self.embeddings = np.load(embeddings_path)
|
| 41 |
+
else:
|
| 42 |
+
# If no embeddings file, we'll use the index directly
|
| 43 |
+
self.embeddings = None
|
| 44 |
+
|
| 45 |
self.chunks = data['chunks']
|
| 46 |
self.metadata = data['metadata']
|
| 47 |
print(f"✅ Vector store loaded from: {path}")
|
|
|
|
| 52 |
continue
|
| 53 |
|
| 54 |
if not loaded:
|
| 55 |
+
raise Exception("Could not load vector store from any path.")
|
| 56 |
|
| 57 |
except Exception as e:
|
| 58 |
raise Exception(f"Error loading vector store: {str(e)}")
|
| 59 |
|
| 60 |
+
def get_embedding_huggingface(self, text: str) -> np.ndarray:
    """Embed ``text`` via the Hugging Face Inference API.

    Uses the hosted all-MiniLM-L6-v2 feature-extraction pipeline so no
    local model has to be downloaded. The API token is read from the
    ``HF_TOKEN`` environment variable (empty token falls back to
    unauthenticated access with stricter rate limits).

    Never raises: on any HTTP error, timeout, or other exception it
    falls back to :meth:`simple_embedding`.

    Args:
        text: The query or document text to embed.

    Returns:
        A numpy array containing the embedding (384 dims for MiniLM,
        same size for the hash-based fallback vector).
    """
    try:
        API_URL = "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/all-MiniLM-L6-v2"
        headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN', '')}"}

        response = requests.post(
            API_URL,
            headers=headers,
            json={"inputs": text, "options": {"wait_for_model": True}},
            # Without a timeout a stalled API call would block the whole
            # request forever; 30s is generous even with wait_for_model.
            timeout=30,
        )

        if response.status_code == 200:
            return np.array(response.json())
        else:
            # Fallback to simple TF-IDF like approach
            print(f"Hugging Face API failed: {response.status_code}")
            return self.simple_embedding(text)

    except Exception as e:
        # Covers network errors, timeouts, and malformed JSON responses.
        print(f"Embedding error: {e}")
        return self.simple_embedding(text)
|
| 82 |
+
|
| 83 |
+
def simple_embedding(self, text: str) -> np.ndarray:
    """Simple fallback embedding (hashed bag-of-words).

    Builds a 384-dimensional vector (matching MiniLM's output size) by
    hashing each lower-cased, whitespace-separated token into a bucket,
    counting occurrences, and L2-normalising the result.

    Fix: the original used the builtin ``hash()``, which is randomised
    per process (PYTHONHASHSEED), so the same text embedded in different
    runs landed in different buckets — making stored vectors useless
    across restarts. ``hashlib.md5`` gives a bucketing that is stable
    across processes and machines.

    Args:
        text: Arbitrary input text.

    Returns:
        A unit-norm float vector of shape (384,); the zero vector for
        empty or whitespace-only input.
    """
    import hashlib  # local import: only needed on this fallback path

    words = text.lower().split()

    # Bag-of-words vector sized to match the MiniLM embedding dimension.
    embedding = np.zeros(384)

    for word in words:
        # Stable (process-independent) hash -> bucket index.
        digest = hashlib.md5(word.encode("utf-8")).digest()
        hash_val = int.from_bytes(digest[:8], "big") % 384
        embedding[hash_val] += 1

    # L2-normalise so dot products behave like cosine similarity.
    norm = np.linalg.norm(embedding)
    if norm > 0:
        embedding = embedding / norm

    return embedding
|
| 103 |
+
|
| 104 |
+
def retrieve_similar_chunks(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
|
| 105 |
"""Retrieve similar medical chunks for a query"""
|
| 106 |
try:
|
| 107 |
+
# Get query embedding
|
| 108 |
+
query_embedding = self.get_embedding_huggingface(query).astype('float32')
|
| 109 |
+
query_embedding = query_embedding.reshape(1, -1)
|
| 110 |
+
|
| 111 |
+
# Normalize for cosine similarity
|
| 112 |
faiss.normalize_L2(query_embedding)
|
| 113 |
|
| 114 |
+
# Search
|
| 115 |
+
scores, indices = self.index.search(query_embedding, min(k * 3, len(self.chunks)))
|
| 116 |
|
| 117 |
results = []
|
| 118 |
+
seen_chunks = set()
|
| 119 |
|
| 120 |
for score, idx in zip(scores[0], indices[0]):
|
| 121 |
if idx < len(self.chunks) and idx not in seen_chunks:
|