# NOTE(review): the original paste carried a header reading "Spaces:" /
# "Runtime error" (twice) — an artifact of the export tool, preserved here
# as a comment so the file remains valid Python.
# Standard library
import logging
import re

# Third-party
import pdfplumber
import torch
from fastapi import UploadFile
from gliner import GLiNER
from pydantic import BaseModel
from typing import List

# Module-level logger, configured at INFO so progress messages are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Entity(BaseModel):
    """A single extracted entity with its location in the source text."""

    # The entity surface string as found in the text.
    entity: str
    # Surrounding text window (filled in after extraction; may be "" initially).
    context: str
    # Character offsets into the full document text (start inclusive, end exclusive).
    start: int
    end: int
# Curated medical labels passed to GLiNER as the candidate entity types.
# GLiNER is a zero-shot NER model, so these strings ARE the label space:
# editing this list directly changes what gets extracted.
MEDICAL_LABELS = [
    "gene", "protein", "protein_isoform", "cell", "disease",
    "phenotypic_feature", "clinical_finding", "anatomical_entity",
    "pathway", "biological_process", "drug", "small_molecule",
    "food_additive", "chemical_mixture", "molecular_entity",
    "clinical_intervention", "clinical_trial", "hospitalization",
    "geographic_location", "environmental_feature", "environmental_process",
    "publication", "journal_article", "book", "patent", "dataset",
    "study_result", "human", "mammal", "plant", "virus", "bacterium",
    "cell_line", "biological_sex", "clinical_attribute",
    "socioeconomic_attribute", "environmental_exposure", "drug_exposure",
    "procedure", "treatment", "device", "diagnostic_aid", "event"
]
# Pick the best available accelerator: Apple-silicon MPS first,
# then NVIDIA CUDA, falling back to CPU.
_backend = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
device = torch.device(_backend)
logger.info(f"Using device: {device}")
# Initialize model.
# NOTE: this downloads/loads the checkpoint at import time, so importing
# this module is slow on first run and requires network access for the
# initial download.
gliner_model = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")
gliner_model.to(device)  # Move model to GPU/MPS if available
def chunk_text(text: str, max_tokens: int = 700) -> List[str]:
    """
    Split text into chunks that respect sentence boundaries and a token limit.

    The default of 700 tokens leaves some margin for the model's special
    tokens. "Tokens" here are a rough estimate (words + punctuation marks),
    not true model tokens.

    Args:
        text (str): Input text to chunk
        max_tokens (int): Maximum number of estimated tokens per chunk

    Returns:
        List[str]: List of text chunks; empty list for empty/blank input.
            A single sentence longer than max_tokens still becomes its own
            (oversized) chunk rather than being split mid-sentence.
    """
    # Split into sentences on terminal punctuation (simple approach).
    sentences = re.split(r'(?<=[.!?])\s+', text)

    chunks: List[str] = []
    current_chunk: List[str] = []
    current_length = 0

    for sentence in sentences:
        # Skip empty/whitespace fragments (empty input, trailing whitespace);
        # without this, blank input yielded a single "" chunk.
        if not sentence.strip():
            continue
        # Rough estimation of tokens (words + punctuation).
        sentence_tokens = len(re.findall(r'\w+|[^\w\s]', sentence))
        # Flush the current chunk before it would exceed the limit.
        if current_length + sentence_tokens > max_tokens and current_chunk:
            chunks.append(' '.join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(sentence)
        current_length += sentence_tokens

    # Don't forget the last chunk.
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks
def extract_entities_from_pdf(file: UploadFile) -> List[Entity]:
    """
    Extract medical entities from a PDF file using GLiNER.

    Args:
        file (UploadFile): The uploaded PDF file

    Returns:
        List[Entity]: Extracted entities, each with ~50 characters of
            surrounding context from the full document text.

    Raises:
        Exception: any error from PDF parsing or model inference is logged
            with a traceback and re-raised.
    """
    logger.debug(f"Starting extraction for file: {file.filename}")
    try:
        # pdfplumber accepts the file-like object from the upload directly.
        with pdfplumber.open(file.file) as pdf:
            logger.info(f"Successfully opened PDF with {len(pdf.pages)} pages")
            # Join all pages into a single string. extract_text() returns
            # None for pages with no extractable text (e.g. scanned images),
            # which would crash the join — guard with "".
            pdf_text = " ".join((p.extract_text() or "") for p in pdf.pages)
            logger.info(f"Extracted text length: {len(pdf_text)} characters")

            # Split text into model-sized chunks.
            text_chunks = chunk_text(pdf_text)
            logger.info(f"Split text into {len(text_chunks)} chunks")

            # Extract entities from each chunk.
            all_entities = []
            # Absolute position of the current chunk in the original text.
            # NOTE(review): chunk_text() rejoins sentences with single
            # spaces, so these offsets can drift from pdf_text wherever the
            # original whitespace was wider — acceptable for a +/-50 char
            # context window, but confirm if exact offsets are required.
            base_offset = 0
            for chunk in text_chunks:
                chunk_entities = gliner_model.predict_entities(chunk, MEDICAL_LABELS, threshold=0.7)
                for ent in chunk_entities:
                    if len(ent["text"]) <= 2:  # Skip very short entities
                        continue
                    # Prefer the model-reported span; chunk.find() only
                    # locates the FIRST occurrence, so repeated entities
                    # would all map to the same position otherwise.
                    start_idx = ent.get("start", -1)
                    if start_idx == -1 or chunk[start_idx:start_idx + len(ent["text"])] != ent["text"]:
                        start_idx = chunk.find(ent["text"])
                    if start_idx != -1:
                        all_entities.append(Entity(
                            entity=ent["text"],
                            context="",  # Filled below from the full text
                            start=base_offset + start_idx,
                            end=base_offset + start_idx + len(ent["text"])
                        ))
                base_offset += len(chunk) + 1  # +1 for the space between chunks

            # Attach surrounding context from the complete original text.
            final_entities = []
            for ent in all_entities:
                context_start = max(0, ent.start - 50)
                context_end = min(len(pdf_text), ent.end + 50)
                final_entities.append(Entity(
                    entity=ent.entity,
                    context=pdf_text[context_start:context_end],
                    start=ent.start,
                    end=ent.end
                ))
            logger.info(f"Returning {len(final_entities)} processed entities")
            return final_entities
    except Exception as e:
        logger.error(f"Error during extraction: {str(e)}", exc_info=True)
        raise