File size: 3,472 Bytes
6c58cf4
 
 
 
 
dfa6a46
6c58cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
from pathlib import Path
from typing import List, Optional
from langchain_community.document_loaders import PyPDFLoader, ArxivLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema import Document
from project.logger.logging import get_logger

logger = get_logger(__name__)

class DataPreparation:
    def __init__(
        self,
        data_dir: str = "data",
        chunk_size: int = 1000,
        chunk_overlap: int = 200
    ):
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(exist_ok=True)
        
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )
        logger.info(f"DataPreparation initialized with chunk_size={chunk_size}")
    
    def load_attention_paper(self, arxiv_id: str = "1706.03762") -> List[Document]:
        pdf_path = self.data_dir / "attention-is-all-you-need.pdf"

        if pdf_path.exists():
            logger.info(f"Loading PDF from local file: {pdf_path}")
            return self._load_pdf(str(pdf_path))

        logger.info(f"PDF not found locally. Downloading from ArXiv: {arxiv_id}")
        try:
            loader = ArxivLoader(query=arxiv_id, load_max_docs=1)
            documents = loader.load()
            if documents:
                logger.info(f"Successfully downloaded paper from ArXiv")
                return documents
            else:
                raise ValueError("No documents returned from ArXiv")
                
        except Exception as e:
            logger.error(f"Failed to download from ArXiv: {str(e)}")
            raise
    
    def _load_pdf(self, pdf_path: str) -> List[Document]:
        try:
            loader = PyPDFLoader(pdf_path)
            documents = loader.load()
            logger.info(f"Loaded {len(documents)} pages from PDF")
            return documents
        except Exception as e:
            logger.error(f"Failed to load PDF: {str(e)}")
            raise
    
    def load_custom_pdf(self, pdf_path: str) -> List[Document]:
        if not Path(pdf_path).exists():
            raise FileNotFoundError(f"PDF not found: {pdf_path}")
        
        return self._load_pdf(pdf_path)
    
    def split_documents(self, documents: List[Document]) -> List[Document]:
        try:
            chunks = self.text_splitter.split_documents(documents)
            logger.info(f"Split documents into {len(chunks)} chunks")
            return chunks
        except Exception as e:
            logger.error(f"Failed to split documents: {str(e)}")
            raise
    
    def prepare_documents(
        self, 
        pdf_path: Optional[str] = None,
        use_attention_paper: bool = True
    ) -> List[Document]:
        try:
            if pdf_path:
                documents = self.load_custom_pdf(pdf_path)
            elif use_attention_paper:
                documents = self.load_attention_paper()
            else:
                raise ValueError("Either provide pdf_path or set use_attention_paper=True")
            
            chunks = self.split_documents(documents)
            
            logger.info(f"Document preparation complete: {len(chunks)} chunks ready")
            return chunks
            
        except Exception as e:
            logger.error(f"Document preparation failed: {str(e)}")
            raise