# Commit 698a3c4: Add support for CUDA availability check in PDF processing
# and raise HTTPException for unsupported scanned PDFs.
| import asyncio | |
| import os | |
| import re | |
| import tempfile | |
| from pathlib import Path | |
| from typing import List | |
| import aiofiles | |
| import fitz | |
| import torch | |
| from fastapi import HTTPException, UploadFile | |
| from loguru import logger | |
| from src.utils import TextExtractor, model_manager | |
class PDFProcessorService:
    """Service that extracts text from uploaded PDFs (digital or scanned)."""

    def __init__(self):
        """Initialize the service and eagerly warm up the shared models."""
        logger.info("Initializing PDFProcessorService")
        self._ensure_models_loaded()
| def _ensure_models_loaded(self): | |
| if not model_manager.models_loaded: | |
| logger.info("Models not loaded, initializing model manager...") | |
| _ = model_manager.doctr_model | |
| logger.debug("Model manager initialization completed") | |
| def doctr_model(self): | |
| return model_manager.doctr_model | |
| def device(self): | |
| return model_manager.device | |
    async def __aenter__(self):
        """Enter the async context manager; returns the service itself."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Exit the async context manager; no cleanup is required."""
        pass
| async def is_pdf_scanned(self, pdf_path: str) -> bool: | |
| logger.debug(f"Checking if PDF is scanned: {pdf_path}") | |
| def _check_scanned(): | |
| try: | |
| doc = fitz.open(pdf_path) | |
| for page in doc: | |
| text = page.get_text() | |
| if text.strip(): | |
| return False | |
| return True | |
| except Exception as e: | |
| logger.error(f"Error checking if PDF is scanned: {e}") | |
| raise | |
| return await asyncio.get_event_loop().run_in_executor(None, _check_scanned) | |
| async def save_uploaded_file(self, uploaded_file: UploadFile) -> str: | |
| logger.info(f"Saving uploaded file: {uploaded_file.filename}") | |
| try: | |
| file_name = uploaded_file.filename | |
| suffix = Path(file_name).suffix | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: | |
| temp_path = tmp.name | |
| async with aiofiles.open(temp_path, "wb") as f: | |
| await f.write(await uploaded_file.read()) | |
| logger.debug(f"File saved to temporary path: {temp_path}") | |
| return temp_path | |
| except Exception as e: | |
| logger.error(f"Error saving uploaded file: {e}") | |
| raise | |
| async def extract_text_from_digital_pdf(self, pdf_path: str) -> List[List[str]]: | |
| logger.debug(f"Extracting text from digital PDF: {pdf_path}") | |
| async def _extract_text(): | |
| try: | |
| doc = fitz.open(pdf_path) | |
| extracted_data = [] | |
| for page in doc: | |
| ptext = page.get_text() | |
| if ptext: | |
| data = [] | |
| for line in ptext.splitlines(): | |
| cleaned_line = await self._split_on_repeated_pattern( | |
| line.strip() | |
| ) | |
| if cleaned_line: | |
| data.append(cleaned_line[0]) | |
| extracted_data.append(data) | |
| logger.info( | |
| f"Successfully extracted text from {len(extracted_data)} pages" | |
| ) | |
| return extracted_data | |
| except Exception as e: | |
| logger.error(f"Error extracting text from digital PDF: {e}") | |
| raise | |
| return await asyncio.get_event_loop().run_in_executor(None, _extract_text) | |
| async def _split_on_repeated_pattern( | |
| self, line: str, min_space: int = 10 | |
| ) -> List[str]: | |
| logger.debug(f"Processing line for repeated patterns: {line[:50]}...") | |
| import re | |
| from difflib import SequenceMatcher | |
| original_line = line.strip() | |
| space_spans = [ | |
| (m.start(), len(m.group())) | |
| for m in re.finditer(r" {%d,}" % min_space, original_line) | |
| ] | |
| if not space_spans: | |
| return [original_line] | |
| gaps = [span[1] for span in space_spans] | |
| gap_counts = {} | |
| for g in gaps: | |
| gap_counts[g] = gap_counts.get(g, 0) + 1 | |
| sorted_gaps = sorted( | |
| gap_counts.items(), key=lambda x: x[1] * x[0], reverse=True | |
| ) | |
| if not sorted_gaps: | |
| return [original_line] | |
| dominant_gap = sorted_gaps[0][0] | |
| chunks = re.split(rf" {{%d,}}" % dominant_gap, original_line) | |
| base = chunks[0].strip() | |
| repeated = False | |
| for chunk in chunks[1:]: | |
| chunk = chunk.strip() | |
| if chunk and SequenceMatcher(None, base, chunk).ratio() > 0.8: | |
| repeated = True | |
| break | |
| return [base] if repeated else [original_line] | |
| async def process_pdf(self, file): | |
| logger.info(f"Processing PDF file: {file.filename}") | |
| try: | |
| pdf_path = await self.save_uploaded_file(file) | |
| is_scanned = await self.is_pdf_scanned(pdf_path) | |
| text_extractor = TextExtractor(self.doctr_model) | |
| if is_scanned: | |
| if not torch.cuda.is_available(): | |
| raise HTTPException( | |
| status_code=400, detail="Scanned PDFs are not supported." | |
| ) | |
| logger.info(f"PDF {pdf_path} is scanned, using OCR extraction") | |
| extracted_text_list = ( | |
| await text_extractor.extract_lines_with_bbox_from_scanned_pdf( | |
| pdf_path | |
| ) | |
| ) | |
| else: | |
| logger.info(f"PDF {pdf_path} is digital, extracting text directly") | |
| extracted_text_list = await text_extractor.extract_lines_with_bbox( | |
| pdf_path | |
| ) | |
| pdf_text = "" | |
| for block in extracted_text_list: | |
| for line in block: | |
| pdf_text += " " + line["line"] | |
| text_noisy = text_extractor.is_text_noisy(pdf_text) | |
| if text_noisy: | |
| if not torch.cuda.is_available(): | |
| raise HTTPException( | |
| status_code=400, detail="Scanned PDFs are not supported." | |
| ) | |
| logger.warning("Text is noisy, falling back to OCR extraction") | |
| extracted_text_list = ( | |
| await text_extractor.extract_lines_with_bbox_from_scanned_pdf( | |
| pdf_path | |
| ) | |
| ) | |
| logger.info( | |
| f"Successfully processed PDF with {len(extracted_text_list)} text blocks" | |
| ) | |
| return extracted_text_list | |
| except Exception as e: | |
| logger.error(f"Error processing PDF: {e}") | |
| raise | |
| finally: | |
| if os.path.exists(pdf_path): | |
| os.remove(pdf_path) | |
| async def extract_entity(self, text: str): | |
| logger.debug(f"Extracting entities from text: {text[:100]}...") | |
| try: | |
| text = re.sub(r"[^\w\s]", " ", text) | |
| doc = model_manager.spacy_model(text) | |
| entities = {ent.text: ent.label_ for ent in doc.ents} | |
| for key, value in entities.items(): | |
| if value == "ORG": | |
| logger.info(f"Found organization entity: {key}") | |
| return key | |
| if entities: | |
| entity = list(entities.keys())[0] | |
| logger.info(f"Found entity: {entity}") | |
| return entity | |
| logger.debug("No entities found, returning original text") | |
| return text | |
| except Exception as e: | |
| logger.error(f"Error extracting entities: {e}") | |
| return text | |