# MedQuery-Assist / document_classifier.py
from langchain_community.document_loaders import PyPDFLoader
from transformers import pipeline
import torch
from collections import defaultdict
import time


class DocumentClassifier:
    """Zero-shot page classifier for medical PDFs.

    Splits a PDF into small page groups, runs a zero-shot NLI model over each
    group, and maps every page to the most likely document type.
    """

    LABELS = [
        "lab report",
        "prescription",
        "discharge summary",
        "progress note",
        "imaging report",
        "consultation note",
        "operative report",
        "immunization record"
    ]

    def __init__(
        self,
        pages_per_group=2,
        min_confidence=0.35,
        model_name="cross-encoder/nli-deberta-v3-small"
    ):
        self.pages_per_group = pages_per_group
        self.min_confidence = min_confidence
        self.model_name = model_name
        self.classifier = None
        print(f"[Classifier] Loading {model_name}...")
        self._load_model()

    def _load_model(self):
        device = 0 if torch.cuda.is_available() else -1
        self.classifier = pipeline(
            "zero-shot-classification",
            model=self.model_name,
            device=device
        )
        print(f"[Classifier] Ready (device: {'GPU' if device >= 0 else 'CPU'})")

    def classify_document(self, file_path):
        """Classify every page of the PDF at file_path and return a summary dict."""
        start_time = time.time()
        try:
            loader = PyPDFLoader(file_path)
            pages = loader.load()
            if not pages:
                return self._default_result()

            print(f"[Classifier] Analyzing {len(pages)} pages...")
            page_groups = self._create_page_groups(pages)
            print(f"[Classifier] Created {len(page_groups)} groups, classifying in batches...")
            group_results = self._classify_groups_parallel(page_groups)
            page_map = self._build_page_map(group_results)

            # Tally page-level types; the most frequent one becomes the primary type.
            all_types = [p['type'] for p in page_map.values()]
            type_counts = defaultdict(int)
            for t in all_types:
                type_counts[t] += 1
            primary_type = max(type_counts.items(), key=lambda x: x[1])[0]
            unique_types = sorted(set(all_types),
                                  key=lambda t: type_counts[t],
                                  reverse=True)

            result = {
                "primary_type": primary_type,
                "page_classifications": page_map,
                "all_types": unique_types,
                "processing_time": round(time.time() - start_time, 2),
                "total_pages": len(pages)
            }
            print(f"[Classifier] Done in {result['processing_time']}s - "
                  f"Primary: {primary_type}, Types found: {len(unique_types)}")
            return result
        except Exception as e:
            print(f"[Classifier] Error: {e}")
            import traceback
            traceback.print_exc()
            return self._default_result()

    def _create_page_groups(self, pages):
        """Bundle consecutive pages into groups of pages_per_group pages."""
        groups = []
        for i in range(0, len(pages), self.pages_per_group):
            group_pages = pages[i:i + self.pages_per_group]
            page_nums = list(range(i + 1, i + len(group_pages) + 1))
            text = " ".join([p.page_content for p in group_pages])
            # Keep only the head and tail of very long groups to limit input length.
            if len(text) > 2000:
                text = text[:1000] + " ... " + text[-1000:]
            groups.append({
                'text': text,
                'page_numbers': page_nums
            })
        return groups

    def _classify_groups_parallel(self, groups):
        """Classify all page groups in a single batched pipeline call."""
        results = []
        texts = [g['text'] for g in groups]
        # Use the pipeline's native batching: faster than ThreadPoolExecutor,
        # especially on GPU, and avoids thread-safety issues with PyTorch.
        batch_results = self.classifier(texts, self.LABELS, multi_label=True, batch_size=8)
        for group, result in zip(groups, batch_results):
            # Labels are returned sorted by score, so the first entry is the best match.
            primary_type = result['labels'][0]
            primary_score = result['scores'][0]
            if primary_score < self.min_confidence:
                primary_type = 'other'
            scores = {label: score for label, score in zip(result['labels'], result['scores'])}
            results.append({
                'type': primary_type,
                'confidence': primary_score,
                'scores': scores,
                'page_numbers': group['page_numbers']
            })
        return results

    def _classify_single_group(self, group):
        # Kept for single-group use if needed directly
        text = group['text']
        if not text.strip():
            return {'type': 'other', 'confidence': 0.0, 'scores': {}}
        result = self.classifier(text, self.LABELS, multi_label=True)
        primary_type = result['labels'][0]
        primary_score = result['scores'][0]
        if primary_score < self.min_confidence:
            primary_type = 'other'
        scores = {
            label: score
            for label, score in zip(result['labels'], result['scores'])
        }
        return {
            'type': primary_type,
            'confidence': primary_score,
            'scores': scores
        }

    def _build_page_map(self, group_results):
        page_map = {}
        for group in group_results:
            page_nums = group.get('page_numbers', [])
            doc_type = group.get('type', 'other')
            confidence = group.get('confidence', 0.0)
            for page_num in page_nums:
                page_map[page_num] = {
                    'type': doc_type,
                    'confidence': round(confidence, 2)
                }
        return page_map

    def _default_result(self):
        return {
            "primary_type": "other",
            "page_classifications": {},
            "all_types": ["other"],
            "processing_time": 0.0,
            "total_pages": 0
        }
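

# Minimal usage sketch. The PDF path below is a placeholder, not a file that
# ships with this repo; classify_document returns the summary dict built above.
if __name__ == "__main__":
    classifier = DocumentClassifier(pages_per_group=2, min_confidence=0.35)
    summary = classifier.classify_document("sample_records.pdf")
    print(f"Primary type: {summary['primary_type']}")
    for page_num, info in summary["page_classifications"].items():
        print(f"  Page {page_num}: {info['type']} (confidence {info['confidence']})")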