Spaces:
Sleeping
Sleeping
import time
from collections import Counter, defaultdict

import torch
from langchain_community.document_loaders import PyPDFLoader
from transformers import pipeline
class DocumentClassifier:
    """Classify the pages of a PDF into the document types listed in ``LABELS``.

    Pages are loaded with ``PyPDFLoader``, joined into groups of
    ``pages_per_group`` consecutive pages, and each group is scored by a
    Hugging Face zero-shot-classification pipeline.  Every page inherits the
    label of its group; a group whose best score is below ``min_confidence``
    (or that has no text at all) is labelled ``"other"``.
    """

    # Candidate labels handed to the zero-shot pipeline.
    LABELS = [
        "lab report",
        "prescription",
        "discharge summary",
        "progress note",
        "imaging report",
        "consultation note",
        "operative report",
        "immunization record",
    ]

    # Label used when nothing reaches ``min_confidence`` or a group is blank.
    FALLBACK_TYPE = "other"

    # Group text longer than this is shortened to its first and last halves.
    MAX_GROUP_CHARS = 2000

    # Number of groups sent through the pipeline per forward pass.
    BATCH_SIZE = 8

    def __init__(
        self,
        pages_per_group=2,
        min_confidence=0.35,
        model_name="cross-encoder/nli-deberta-v3-small",
    ):
        """Store the configuration and load the model immediately.

        Args:
            pages_per_group: Number of consecutive pages classified together.
            min_confidence: Minimum top score required to keep the predicted
                label; lower scores are replaced by ``"other"``.
            model_name: Model identifier passed to ``transformers.pipeline``.
        """
        self.pages_per_group = pages_per_group
        self.min_confidence = min_confidence
        self.model_name = model_name
        self.classifier = None
        print(f"[Classifier] Loading {model_name}...")
        self._load_model()

    def _load_model(self):
        """Build the zero-shot pipeline on the first GPU if CUDA is available, else CPU."""
        # The pipeline API uses device index 0 for the first GPU and -1 for CPU.
        device = 0 if torch.cuda.is_available() else -1
        self.classifier = pipeline(
            "zero-shot-classification",
            model=self.model_name,
            device=device,
        )
        print(f"[Classifier] Ready (device: {'GPU' if device >= 0 else 'CPU'})")

    def classify_document(self, file_path):
        """Classify every page of the PDF at ``file_path``.

        Returns:
            A dict with keys ``primary_type`` (most frequent page label),
            ``page_classifications`` (1-based page number -> ``{'type',
            'confidence'}``), ``all_types`` (distinct labels, most frequent
            first), ``processing_time`` (seconds, rounded to 2 decimals) and
            ``total_pages``.  If the PDF has no pages or any step raises, the
            value of ``_default_result()`` is returned instead.
        """
        start_time = time.time()
        try:
            pages = PyPDFLoader(file_path).load()
            if not pages:
                return self._default_result()

            print(f"[Classifier] Analyzing {len(pages)} pages...")
            page_groups = self._create_page_groups(pages)
            print(f"[Classifier] Created {len(page_groups)} groups, classifying in parallel...")
            group_results = self._classify_groups_parallel(page_groups)
            page_map = self._build_page_map(group_results)

            unique_types = self._rank_types(page_map)
            if not unique_types:
                return self._default_result()
            # The ranking is most-frequent-first, so its head is the primary type.
            primary_type = unique_types[0]

            result = {
                "primary_type": primary_type,
                "page_classifications": page_map,
                "all_types": unique_types,
                "processing_time": round(time.time() - start_time, 2),
                "total_pages": len(pages),
            }
            print(f"[Classifier] Done in {result['processing_time']}s - "
                  f"Primary: {primary_type}, Types found: {len(unique_types)}")
            return result
        except Exception as e:
            # Deliberate best-effort boundary: report the failure and fall
            # back to the default result rather than propagating.
            print(f"[Classifier] Error: {e}")
            import traceback
            traceback.print_exc()
            return self._default_result()

    def _rank_types(self, page_map):
        """Return the distinct page labels in ``page_map``, most frequent first.

        Labels with equal counts keep the order in which they first appear in
        ``page_map``, so the ranking is deterministic from run to run.
        """
        counts = Counter(entry['type'] for entry in page_map.values())
        return [label for label, _ in counts.most_common()]

    def _create_page_groups(self, pages):
        """Join consecutive pages into groups of ``pages_per_group``.

        Each group is ``{'text': str, 'page_numbers': [int, ...]}`` with
        1-based page numbers.  Text longer than ``MAX_GROUP_CHARS`` is reduced
        to its first and last ``MAX_GROUP_CHARS // 2`` characters joined by
        ``" ... "``.
        """
        half = self.MAX_GROUP_CHARS // 2
        groups = []
        for start in range(0, len(pages), self.pages_per_group):
            chunk = pages[start:start + self.pages_per_group]
            text = " ".join(page.page_content for page in chunk)
            if len(text) > self.MAX_GROUP_CHARS:
                text = text[:half] + " ... " + text[-half:]
            groups.append({
                'text': text,
                'page_numbers': list(range(start + 1, start + len(chunk) + 1)),
            })
        return groups

    def _classify_groups_parallel(self, groups):
        """Classify all groups in one batched pipeline call.

        Groups whose text is empty or whitespace-only are not sent to the
        model; they get the fallback label with confidence 0.0, matching
        ``_classify_single_group``.

        Returns:
            One dict per input group, in order, with keys ``type``,
            ``confidence``, ``scores`` and ``page_numbers``.
        """
        if not groups:
            return []

        with_text = [i for i, group in enumerate(groups) if group['text'].strip()]
        predictions = {}
        if with_text:
            # Use the pipeline's native batching rather than threads.
            raw = self.classifier(
                [groups[i]['text'] for i in with_text],
                self.LABELS,
                multi_label=True,
                batch_size=self.BATCH_SIZE,
            )
            # NOTE(review): some transformers releases appear to return a bare
            # dict for a single-sequence batch; normalise so zip() is safe.
            if isinstance(raw, dict):
                raw = [raw]
            predictions = dict(zip(with_text, raw))

        results = []
        for i, group in enumerate(groups):
            if i in predictions:
                outcome = self._interpret_prediction(predictions[i])
            else:
                outcome = self._blank_outcome()
            outcome['page_numbers'] = group['page_numbers']
            results.append(outcome)
        return results

    def _classify_single_group(self, group):
        """Classify one group; returns ``{'type', 'confidence', 'scores'}``.

        Kept for direct single-group use.  Blank text is not sent to the model.
        """
        text = group['text']
        if not text.strip():
            return self._blank_outcome()
        return self._interpret_prediction(
            self.classifier(text, self.LABELS, multi_label=True)
        )

    def _interpret_prediction(self, prediction):
        """Turn one pipeline result into ``{'type', 'confidence', 'scores'}``.

        The first label/score pair is taken as the top prediction; if its
        score is below ``min_confidence`` the type becomes the fallback label
        while the reported confidence stays the raw top score.
        """
        labels = prediction['labels']
        scores = prediction['scores']
        top_label, top_score = labels[0], scores[0]
        if top_score < self.min_confidence:
            top_label = self.FALLBACK_TYPE
        return {
            'type': top_label,
            'confidence': top_score,
            'scores': dict(zip(labels, scores)),
        }

    def _blank_outcome(self):
        """Result used for a group that has no text to classify."""
        return {'type': self.FALLBACK_TYPE, 'confidence': 0.0, 'scores': {}}

    def _build_page_map(self, group_results):
        """Expand group results into a per-page mapping.

        Returns:
            ``{page_number: {'type': str, 'confidence': float}}`` where the
            confidence is rounded to 2 decimals.  Missing keys in a group
            default to no pages, ``'other'`` and 0.0 respectively.
        """
        page_map = {}
        for group in group_results:
            entry_type = group.get('type', 'other')
            entry_conf = round(group.get('confidence', 0.0), 2)
            for page_num in group.get('page_numbers', []):
                # A fresh dict per page so entries are never shared.
                page_map[page_num] = {'type': entry_type, 'confidence': entry_conf}
        return page_map

    def _default_result(self):
        """Result returned for an empty PDF or when classification fails."""
        return {
            "primary_type": "other",
            "page_classifications": {},
            "all_types": ["other"],
            "processing_time": 0.0,
            "total_pages": 0,
        }