import hashlib
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from document_classifier import DocumentClassifier
class RAG_Setup:
    """RAG backend for patient medical records.

    Ingests PDF files into a persistent Chroma vector store (classifying
    each page and tagging every chunk with its document type) and retrieves
    relevant passages for a given user's query.
    """

    def __init__(self):
        # Sentence-level embedding model, used for both indexing and querying.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-mpnet-base-v2"
        )
        # Persistent Chroma collection holding every user's document chunks.
        self.vector_store = Chroma(
            collection_name="medical_history_collection",
            embedding_function=self.embeddings,
            persist_directory="data/patient_record_db",
        )
        # Zero-shot page classifier; tags chunks with a per-page document type.
        self.classifier = DocumentClassifier(
            pages_per_group=2,
            min_confidence=0.35,
            model_name="cross-encoder/nli-deberta-v3-small",
        )

    def _calculate_file_hash(self, file_path):
        """Return the SHA-256 hex digest of the file, read in 8 KiB chunks."""
        sha256 = hashlib.sha256()
        with open(file_path, 'rb') as f:
            while chunk := f.read(8192):
                sha256.update(chunk)
        return sha256.hexdigest()

    def _is_file_uploaded(self, file_hash, user_id=None):
        """Return True if a file with this hash is already indexed.

        When ``user_id`` is given the check is scoped to that user.  This
        matters because retrieval filters by ``user_id``: a purely global
        hash match would skip ingestion for a second user who uploads the
        same file, leaving that document invisible to them.
        """
        if user_id:
            # Chroma requires an explicit $and to combine metadata clauses.
            where = {"$and": [{"file_hash": file_hash}, {"user_id": user_id}]}
        else:
            where = {"file_hash": file_hash}
        results = self.vector_store.get(where=where, limit=1)
        return len(results['ids']) > 0

    def _extract_content(self, file_path):
        """Load a PDF, returning one Document per page."""
        pdf_loader = PyPDFLoader(file_path)
        return pdf_loader.load()

    def _split_content(self, content):
        """Split documents into ~1000-char chunks with start offsets.

        chunk_overlap=100 matches the splitter used in ``store_data`` so
        both paths chunk identically (previously this helper fell back to
        the library default overlap).
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            add_start_index=True,
        )
        return text_splitter.split_documents(content)

    def _embed_content(self, chunks):
        """Embed and persist the chunks into the vector store."""
        self.vector_store.add_documents(chunks)

    def store_data(self, file_path, user_id=None):
        """Classify, chunk, tag, and embed a PDF into the vector store.

        Returns a status dict: ``"skipped"`` when this user has already
        uploaded the file, ``"success"`` with chunk/classification details,
        or ``"error"`` with the failure message.
        """
        file_hash = self._calculate_file_hash(file_path)
        # Scope the duplicate check to the uploading user: retrieval filters
        # by user_id, so another user's copy of the same file must still be
        # ingested for them.
        if self._is_file_uploaded(file_hash, user_id):
            return {
                "status": "skipped",
                "message": "File already exists in database"
            }
        try:
            print("[RAG] Classifying document...")
            classification = self.classifier.classify_document(file_path)
            # Maps 1-based page number -> {'type': ..., 'confidence': ...}.
            page_map = classification['page_classifications']
            print(f"[RAG] Primary type: {classification['primary_type']}")
            print(f"[RAG] Found {len(classification['all_types'])} document types "
                  f"across {classification['total_pages']} pages in {classification['processing_time']}s")
            if len(classification['all_types']) > 1:
                print(f"[RAG] All types: {', '.join(classification['all_types'])}")
            loader = PyPDFLoader(file_path)
            pages = loader.load()
            all_chunks = []
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=100,
                add_start_index=True
            )
            # Split page-by-page so each chunk can carry its own page's
            # classification rather than only the document-level type.
            for i, page in enumerate(pages):
                page_num = i + 1
                page_class = page_map.get(page_num, {'type': 'other', 'confidence': 0.0})
                page_chunks = text_splitter.split_documents([page])
                for chunk in page_chunks:
                    chunk.metadata.update({
                        'file_hash': file_hash,
                        'page_number': page_num,
                        'doc_type': page_class['type'],
                        'classification_confidence': page_class['confidence'],
                        'primary_doc_type': classification['primary_type'],
                        # Chroma metadata values must be scalars; join list.
                        'all_doc_types': ','.join(classification['all_types'])
                    })
                    if user_id:
                        chunk.metadata['user_id'] = user_id
                all_chunks.extend(page_chunks)
            self._embed_content(all_chunks)
            print(f"[RAG] Stored {len(all_chunks)} chunks with page-specific labels")
            return {
                "status": "success",
                "message": "File successfully uploaded",
                "chunks": len(all_chunks),
                "primary_type": classification['primary_type'],
                "all_types": classification['all_types'],
                "processing_time": classification['processing_time']
            }
        except Exception as e:
            # Broad catch is deliberate: ingestion failures are reported as a
            # structured error result rather than crashing the caller.
            print(f"[RAG] Error: {e}")
            import traceback
            traceback.print_exc()
            return {
                "status": "error",
                "message": f"Failed to upload file: {str(e)}"
            }

    def retrieve_info(self, user_id: str, query: str):
        """Return the top-5 matching chunks for this user's query.

        Results are joined with a ---DOCUMENT--- separator; a plain string
        message is returned when nothing matches or retrieval fails.
        """
        try:
            print(f"[RAG] Retrieving for user_id: {user_id}, query: {query}")
            # Only chunks tagged with this user's id are searched.
            results = self.vector_store.similarity_search(query, k=5, filter={"user_id": user_id})
            print(f"[RAG] Found {len(results)} results")
            if not results:
                return "No medical history found for this query."
            content = "\n\n---DOCUMENT---\n\n".join([doc.page_content for doc in results])
            return content
        except Exception as e:
            print(f"[RAG] Error retrieving medical record: {str(e)}")
            return f"Failed to retrieve medical record: {str(e)}"