import re
import logging
from typing import List

import fitz  # pip install pymupdf
import torch
from nltk.tokenize import sent_tokenize  # requires the NLTK "punkt" tokenizer data
from transformers import pipeline, AutoTokenizer
from unidecode import unidecode

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class BERTRetriever:
    """
    BERT-based evidence retrieval using extractive question answering.
    """

    def __init__(self, model_name: str = "deepset/deberta-v3-large-squad2"):
        """
        Initialize the BERT evidence retriever.

        Args:
            model_name: HuggingFace model for question answering
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.qa_pipeline = pipeline(
            "question-answering",
            model=model_name,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )
        # Maximum context length for the model
        self.max_length = self.tokenizer.model_max_length
        logger.info(f"Initialized BERT retriever with model: {model_name}")

    def _extract_and_clean_text(self, pdf_file: str) -> str:
        """
        Extract and clean text from a PDF file.

        Args:
            pdf_file: Path to PDF file

        Returns:
            Cleaned text from the PDF
        """
        # Read the PDF file as binary
        with open(pdf_file, mode="rb") as f:
            pdf_file_bytes = f.read()

        # Extract text from every page of the PDF
        pdf_doc = fitz.open(stream=pdf_file_bytes, filetype="pdf")
        pdf_text = ""
        for page_num in range(pdf_doc.page_count):
            page = pdf_doc.load_page(page_num)
            pdf_text += page.get_text("text")
        pdf_doc.close()

        # Remove hyphenation at ends of lines
        clean_text = re.sub(r"-\n", "", pdf_text)
        # Replace remaining newline characters with spaces
        clean_text = re.sub(r"\n", " ", clean_text)
        # Transliterate unicode to ASCII
        clean_text = unidecode(clean_text)
        return clean_text

    def _chunk_text(self, text: str, max_chunk_size: int = 3000) -> List[str]:
        """
        Split text into chunks that fit within the model's context window.

        Args:
            text: Input text to chunk
            max_chunk_size: Maximum size per chunk, in characters

        Returns:
            List of text chunks
        """
        sentences = sent_tokenize(text)
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            # Check whether adding this sentence would exceed the limit
            if len(current_chunk) + len(sentence) + 1 <= max_chunk_size:
                current_chunk += " " + sentence if current_chunk else sentence
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = sentence
        # Add the last chunk
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks

    def _format_claim_as_question(self, claim: str) -> str:
        """
        Convert a claim into question form for better QA performance.

        Args:
            claim: Input claim

        Returns:
            Question formatted for the QA model
        """
        # Simple heuristics to convert claims to questions
        claim = claim.strip()
        claim_lower = claim.lower()
        # If already a question, return as-is
        if claim.endswith("?"):
            return claim
        # Convert common claim patterns to questions, matching whole words
        # so that e.g. "this" does not trigger the "is" pattern
        if claim_lower.startswith(("the ", "a ", "an ")):
            return f"What evidence supports that {claim_lower}?"
        elif re.search(r"\b(is|are)\b", claim_lower):
            return f"Is it true that {claim_lower}?"
        elif re.search(r"\b(can|could)\b", claim_lower):
            return f"{claim}?"
        else:
            return f"What evidence supports the claim that {claim_lower}?"

    def retrieve_evidence(self, pdf_file: str, claim: str, top_k: int = 5) -> str:
        """
        Retrieve evidence from a PDF using BERT-based question answering.

        Args:
            pdf_file: Path to PDF file
            claim: Claim to find evidence for
            top_k: Number of evidence passages to retrieve

        Returns:
            Combined evidence text
        """
        try:
            # Extract and clean text from the PDF
            clean_text = self._extract_and_clean_text(pdf_file)
            # Convert the claim to question format
            question = self._format_claim_as_question(claim)
            # Split text into manageable chunks
            chunks = self._chunk_text(clean_text)

            # Get an answer from each chunk
            answers = []
            for i, chunk in enumerate(chunks):
                try:
                    result = self.qa_pipeline(
                        question=question, context=chunk, max_answer_len=200, top_k=1
                    )
                    # Handle both a single answer and a list of answers
                    if isinstance(result, list):
                        result = result[0]
                    if result["score"] > 0.1:  # Confidence threshold
                        # Expand to surrounding context for better evidence,
                        # using the pipeline's character offsets into the chunk
                        start_idx = max(0, result["start"] - 100)
                        end_idx = min(len(chunk), result["end"] + 100)
                        context = chunk[start_idx:end_idx].strip()
                        answers.append(
                            {"text": context, "score": result["score"], "chunk_idx": i}
                        )
                except Exception as e:
                    logger.warning(f"Error processing chunk {i}: {str(e)}")
                    continue

            # Sort by confidence score and take the top k
            answers.sort(key=lambda x: x["score"], reverse=True)
            top_answers = answers[:top_k]

            # Combine evidence passages
            if top_answers:
                evidence_texts = [answer["text"] for answer in top_answers]
                return " ".join(evidence_texts)
            else:
                logger.warning("No evidence found with sufficient confidence")
                return "No relevant evidence found in the document."
        except Exception as e:
            logger.error(f"Error in BERT evidence retrieval: {str(e)}")
            return f"Error retrieving evidence: {str(e)}"


def retrieve_with_deberta(pdf_file: str, claim: str, top_k: int = 5) -> str:
    """
    Wrapper function for DeBERTa-based evidence retrieval,
    compatible with the existing BM25S interface.

    Args:
        pdf_file: Path to PDF file
        claim: Claim to find evidence for
        top_k: Number of evidence passages to retrieve

    Returns:
        Retrieved evidence text
    """
    # Initialize the retriever (in production, this should be cached;
    # see the caching sketch below)
    retriever = BERTRetriever()
    return retriever.retrieve_evidence(pdf_file, claim, top_k)
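

# The wrapper above reloads model weights on every call. A minimal caching
# sketch using functools.lru_cache is shown below; the helper name
# _get_cached_retriever is illustrative and not part of the original interface.
from functools import lru_cache


@lru_cache(maxsize=2)
def _get_cached_retriever(model_name: str = "deepset/deberta-v3-large-squad2") -> BERTRetriever:
    # Build each retriever once per model name and reuse it across calls
    return BERTRetriever(model_name=model_name)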


# Alternative lightweight model for faster inference
class DistilBERTRetriever(BERTRetriever):
    """
    Lightweight version using a smaller, faster model.
    """

    def __init__(self):
        super().__init__(model_name="distilbert-base-cased-distilled-squad")


def retrieve_with_distilbert(pdf_file: str, claim: str, top_k: int = 5) -> str:
    """
    Fast DistilBERT-based evidence retrieval.

    Args:
        pdf_file: Path to PDF file
        claim: Claim to find evidence for
        top_k: Number of evidence passages to retrieve

    Returns:
        Retrieved evidence text
    """
    retriever = DistilBERTRetriever()
    return retriever.retrieve_evidence(pdf_file, claim, top_k)
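

# Minimal usage sketch; the PDF path and claim below are placeholders,
# not values from the original code.
if __name__ == "__main__":
    sample_claim = "The proposed method improves accuracy on the benchmark."
    evidence = retrieve_with_distilbert("paper.pdf", sample_claim, top_k=3)
    print(evidence)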