yiqing111 commited on
Commit
8255e91
·
verified ·
1 Parent(s): b061aa3

Upload 7 files

Browse files
script/chunk.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from typing import List, Dict
4
+ from tqdm import tqdm
5
+
6
class SimpleTextChunker:
    """Split extracted documents into overlapping text chunks for embedding.

    Two strategies:
      * simple (default): paragraph-aware chunking with character overlap,
        falling back to a fixed-size sliding window when the text has no
        usable paragraph structure.
      * recursive: split on Markdown headers first (top level only), then
        chunk each section with the simple strategy.
    """

    def __init__(self,
                 chunk_size: int = 200,
                 chunk_overlap: int = 20,
                 recursive: bool = False,
                 max_recursion_depth: int = 3):
        # chunk_size and chunk_overlap are measured in characters.
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.recursive = recursive
        self.max_recursion_depth = max_recursion_depth

    def is_mainly_chinese(self, text: str) -> bool:
        """Return True when more than half of the characters are CJK ideographs."""
        if not text:
            return False

        chinese_chars = sum(1 for char in text if '\u4e00' <= char <= '\u9fff')
        return chinese_chars / len(text) > 0.5

    def _make_chunk(self, content: str, source: str, index: int) -> Dict:
        # One chunk record, shared by both chunking strategies.
        return {
            "source": source,
            "content": content,
            "chunk_index": index,
            "is_chinese": self.is_mainly_chinese(content)
        }

    def simple_chunk_with_overlap(self, text: str, source: str) -> List[Dict]:
        """Chunk *text*, preferring paragraph boundaries when they exist.

        Returns a list of dicts with keys source/content/chunk_index/is_chinese.
        """
        chunks = []

        # Check if we should try to split on paragraph boundaries
        paragraphs = []
        if '\n\n' in text:
            paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]

        # If we have meaningful paragraphs (more than one, and none longer
        # than a whole chunk), use them as base units.
        if paragraphs and len(paragraphs) > 1 and max(len(p) for p in paragraphs) < self.chunk_size:
            current_chunk = []
            current_size = 0

            for para in paragraphs:
                para_size = len(para)

                # Flush the current chunk when adding this paragraph would
                # overflow it (and we already have content).
                if current_size + para_size > self.chunk_size and current_chunk:
                    chunk_text = '\n\n'.join(current_chunk)
                    chunks.append(self._make_chunk(chunk_text, source, len(chunks)))

                    # Keep trailing paragraphs whose combined length fits the
                    # overlap budget; they seed the next chunk.
                    overlap_size = 0
                    overlap_paras = []
                    for p in reversed(current_chunk):
                        if overlap_size + len(p) <= self.chunk_overlap:
                            overlap_paras.insert(0, p)
                            overlap_size += len(p)
                        else:
                            break

                    current_chunk = overlap_paras
                    current_size = overlap_size

                current_chunk.append(para)
                current_size += para_size

            # Emit whatever is left after the last paragraph.
            if current_chunk:
                chunk_text = '\n\n'.join(current_chunk)
                chunks.append(self._make_chunk(chunk_text, source, len(chunks)))
        else:
            # Character-based fallback. Clamp the step to >= 1 so a
            # configuration with chunk_overlap >= chunk_size cannot make
            # range() raise ValueError (step 0) or go backwards.
            step = max(1, self.chunk_size - self.chunk_overlap)
            for i in range(0, len(text), step):
                chunk_end = min(i + self.chunk_size, len(text))
                if chunk_end <= i:
                    break
                chunks.append(self._make_chunk(text[i:chunk_end], source, len(chunks)))

        return chunks

    def recursive_chunk(self, text: str, source: str, depth: int = 0) -> List[Dict]:
        """Chunk by Markdown sections at the top level, recursing into each section."""
        # Small enough (or recursion budget exhausted): emit a single chunk.
        if len(text) <= self.chunk_size or depth >= self.max_recursion_depth:
            return [{
                "source": source,
                "content": text,
                "chunk_index": 0,
                "recursion_depth": depth,
                "is_chinese": self.is_mainly_chinese(text)
            }]

        # Top level only: split on Markdown headers ("# ", "## ", ...).
        if depth == 0 and '\n#' in text:
            sections = re.split(r'\n(#+ )', text)
            if len(sections) > 1:
                combined_sections = []

                # Keep any preamble before the first header; the previous
                # version dropped sections[0] and silently lost that text.
                if sections[0].strip():
                    combined_sections.append(sections[0].strip())

                # re.split with a capture group alternates [marker, body];
                # re-attach each header marker to its section body.
                for i in range(1, len(sections), 2):
                    if i + 1 < len(sections):
                        combined_sections.append(sections[i] + sections[i + 1])
                    else:
                        combined_sections.append(sections[i])

                # Recursively process each section.
                all_chunks = []
                for i, section in enumerate(combined_sections):
                    section_chunks = self.recursive_chunk(section, source, depth + 1)

                    # Re-number chunks globally and tag their section.
                    for j, chunk in enumerate(section_chunks):
                        chunk["chunk_index"] = len(all_chunks) + j
                        chunk["section_index"] = i

                    all_chunks.extend(section_chunks)

                return all_chunks

        # No natural sections, or not at the top level: overlap chunking.
        return self.simple_chunk_with_overlap(text, source)

    def process_document(self, document: Dict) -> List[Dict]:
        """Chunk one extracted-document dict (as produced by PDFTextExtractor).

        Returns [] when the document has no text or extraction failed.
        """
        if not document.get("text") or not document.get("success", False):
            print(f"Skipping document {document.get('filename', 'unknown')}: No text or extraction failed")
            return []

        text = document["text"]
        source = document.get("filename", "unknown")

        if self.recursive:
            chunks = self.recursive_chunk(text, source)
        else:
            chunks = self.simple_chunk_with_overlap(text, source)

        # Add document metadata to each chunk
        for chunk in chunks:
            chunk["document_pages"] = document.get("pages", 0)
            chunk["total_chunks"] = len(chunks)

        return chunks

    def process_documents(self, documents: List[Dict]) -> List[Dict]:
        """Chunk a batch of documents, concatenating all their chunks."""
        all_chunks = []

        for doc in tqdm(documents, desc="Chunking documents"):
            all_chunks.extend(self.process_document(doc))

        print(f"Created {len(all_chunks)} chunks from {len(documents)} documents")
        return all_chunks

    def save_chunks(self, chunks: List[Dict], output_path: str):
        """Write a human-readable dump of *chunks* to *output_path*."""
        with open(output_path, 'w', encoding='utf-8') as f:
            for i, chunk in enumerate(chunks):
                f.write(f"Chunk {i+1}/{len(chunks)}\n")
                f.write(f"Source: {chunk['source']}\n")
                f.write(f"Index: {chunk['chunk_index']}/{chunk['total_chunks']}\n")
                if "recursion_depth" in chunk:
                    f.write(f"Depth: {chunk['recursion_depth']}\n")
                f.write(f"Chinese: {chunk.get('is_chinese', False)}\n")
                f.write("Content:\n")
                f.write(chunk['content'])
                f.write("\n" + "-" * 80 + "\n\n")

        print(f"Saved {len(chunks)} chunks to {output_path}")
script/embedding.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
from sentence_transformers import SentenceTransformer

# Shared multilingual embedding model, loaded once at import time.
embedding_model = SentenceTransformer('intfloat/multilingual-e5-large')


def get_embedding(text):
    """Embed *text* with the shared model and return a plain Python list."""
    vector = embedding_model.encode(text)
    return vector.tolist()
script/llm.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from openai import OpenAI
import os
from dotenv import load_dotenv
# Read DEEPSEEK_API_KEY from a local .env file into the environment.
load_dotenv()

# DeepSeek exposes an OpenAI-compatible endpoint, so the stock OpenAI
# client works once pointed at DeepSeek's base URL.
client = OpenAI(api_key=os.getenv("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com")
7
+
8
def ask_llm(question, context):
    """Answer *question* grounded in *context* via the DeepSeek chat API.

    Args:
        question: The user's question (plain string).
        context: Retrieved note text to ground the answer in.

    Returns:
        The assistant's reply text.
    """
    # The original built a `prompt` string here that was never used — the
    # request below is what is actually sent, so the dead code is removed.
    response = client.chat.completions.create(
        model="deepseek-chat",
        messages=[
            {"role": "system", "content": "You are a helpful assistant who answers based on the given notes."},
            {"role": "user", "content": f"Notes:\n{context}\n\nQuestion: {question}"}
        ]
    )

    return response.choices[0].message.content
27
+
script/parse.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ from typing import List, Dict
4
+ import fitz
5
+ import re
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ from tqdm import tqdm
8
+
9
class PDFTextExtractor:
    """Extract and clean plain text from every PDF in a directory (PyMuPDF)."""

    def __init__(self, input_dir: str, output_dir: str = None):
        self.input_dir = input_dir
        self.output_dir = output_dir or os.path.join(input_dir, "extracted_text")

        # Ensure output directory exists
        os.makedirs(self.output_dir, exist_ok=True)

    def get_pdf_files(self) -> List[str]:
        """Return the paths of all *.pdf / *.PDF files in the input directory."""
        pdf_files = glob.glob(os.path.join(self.input_dir, "*.pdf"))
        pdf_files.extend(glob.glob(os.path.join(self.input_dir, "*.PDF")))

        print(f"Found {len(pdf_files)} PDF files in directory {self.input_dir}")
        return pdf_files

    def extract_text_from_pdf(self, pdf_path: str) -> Dict:
        """Extract and clean the text of one PDF.

        Returns a result dict with keys:
            filename, path, success, text, pages, error.
        On failure, success is False and error holds the message.
        """
        filename = os.path.basename(pdf_path)
        result = {
            "filename": filename,
            "path": pdf_path,
            "success": False,
            "text": "",
            "pages": 0,
            "error": None
        }

        try:
            doc = fitz.open(pdf_path)
            try:
                result["pages"] = len(doc)

                full_text = ""
                for page_num in range(len(doc)):
                    page = doc.load_page(page_num)
                    # "text" mode extracts plain text, ignoring tables and images
                    page_text = page.get_text("text")
                    full_text += page_text + "\n\n"  # blank line separates pages

                result["text"] = self.clean_text(full_text)
                result["success"] = True
            finally:
                # Close even when a page fails mid-extraction; previously the
                # document handle leaked on any exception.
                doc.close()

        except Exception as e:
            # Report the actual filename (was hard-coded "(unknown)").
            error_msg = f"Error extracting {filename}: {str(e)}"
            print(error_msg)
            result["error"] = error_msg

        return result

    def clean_text(self, text: str) -> str:
        """Normalize extracted text while preserving paragraph breaks."""
        # Remove unprintable characters, but keep Chinese, English, numbers,
        # basic punctuation and whitespace
        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9.,!?;:()\'",。!?、;:《》【】「」\s]', '', text)

        # Collapse runs of spaces/tabs only. Newlines are kept so the "\n\n"
        # paragraph boundaries survive for the downstream chunker — the old
        # r'\s+' -> ' ' pass flattened every newline and broke paragraph mode.
        text = re.sub(r'[ \t]+', ' ', text)
        text = re.sub(r' *\n *', '\n', text)

        # Remove consecutive empty lines
        text = re.sub(r'\n{3,}', '\n\n', text)

        # Fix spacing issues between Chinese and English
        text = re.sub(r'([a-zA-Z])([\u4e00-\u9fa5])', r'\1 \2', text)
        text = re.sub(r'([\u4e00-\u9fa5])([a-zA-Z])', r'\1 \2', text)

        return text.strip()

    def save_extracted_text(self, extraction_result: Dict) -> None:
        """Save the extracted text to a .txt file next to the PDF's base name."""
        if not extraction_result["success"]:
            return

        # Create output filename based on original filename
        base_name = os.path.splitext(extraction_result["filename"])[0]
        output_path = os.path.join(self.output_dir, f"{base_name}.txt")

        # Write to text file
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(extraction_result["text"])

        print(f"Saved extracted text to {output_path}")

    def process_single_pdf(self, pdf_path: str) -> Dict:
        """Extract one PDF, persist its text on success, and return the result dict."""
        extraction_result = self.extract_text_from_pdf(pdf_path)

        if extraction_result["success"]:
            self.save_extracted_text(extraction_result)
            print(f"Successfully processed {extraction_result['filename']} ({extraction_result['pages']} pages)")
        else:
            print(f"Failed to process {extraction_result['filename']}: {extraction_result['error']}")

        return extraction_result

    def extract_all_pdfs(self, max_workers: int = 4) -> List[Dict]:
        """Extract every PDF in the input directory in parallel threads."""
        pdf_files = self.get_pdf_files()
        results = []

        if not pdf_files:
            print("No PDF files found")
            return results

        # Thread pool: extraction is I/O-heavy, so threads overlap the waits.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # tqdm provides the progress bar over the mapped results
            for result in tqdm(executor.map(self.process_single_pdf, pdf_files),
                               total=len(pdf_files),
                               desc="Processing PDF files"):
                results.append(result)

        # Count successful and failed processes
        success_count = sum(1 for r in results if r["success"])
        fail_count = len(results) - success_count

        print(f"PDF processing completed: {success_count} successful, {fail_count} failed")

        return results
129
+
130
# Usage example
if __name__ == "__main__":
    # Input and output both point at the shared data directory.
    INPUT_DIR = "../data"
    OUTPUT_DIR = "../data"

    extractor = PDFTextExtractor(INPUT_DIR, OUTPUT_DIR)

    # Run the extraction with 4 worker threads.
    results = extractor.extract_all_pdfs(max_workers=4)

    # Summarize the run.
    ok = sum(1 for r in results if r['success'])
    print(f"\nProcessed {len(results)} PDF files in total")
    print(f"Successful: {ok}")
    print(f"Failed: {len(results) - ok}")
script/pipeline.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from embedding import get_embedding
2
+ from vector import VectorStore
3
+ from chunk import SimpleTextChunker
4
+ from parse import PDFTextExtractor
5
+
6
def build_knowledge_base(pdf_folder):
    """Parse every PDF in *pdf_folder*, chunk and embed the text, and load
    the vectors into a fresh VectorStore.

    Returns the populated VectorStore.
    """
    # 1. Extract raw text from each PDF in the folder.
    documents = PDFTextExtractor(pdf_folder).extract_all_pdfs()

    # 2. Split the documents into overlapping chunks.
    all_chunks = SimpleTextChunker().process_documents(documents)

    # 3. Embed each chunk and upsert everything into the vector store.
    store = VectorStore()
    embeddings = []
    for chunk in all_chunks:
        embeddings.append(get_embedding(chunk["content"]))

    store.add(embeddings, all_chunks)

    print(f"✅ Knowledge base built with {len(all_chunks)} chunks.")
    return store
script/streamlit_app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import streamlit as st
from dotenv import load_dotenv

from embedding import get_embedding
from vector import VectorStore
from parse import PDFTextExtractor
from chunk import SimpleTextChunker
from llm import ask_llm

# Load environment variables (API keys are read by the imported modules)
load_dotenv()

# Initialize VectorStore once per session; st.session_state survives reruns
if "store" not in st.session_state:
    st.session_state["store"] = VectorStore()


st.title("📚 RAG Note Assistant - Upload & Ask")

# Uploads are persisted to disk so PDFTextExtractor can read them by path
PDF_FOLDER = "pdf_folder"
os.makedirs(PDF_FOLDER, exist_ok=True)

# upload PDF files
uploaded_files = st.file_uploader("Upload new PDF documents", accept_multiple_files=True, type=["pdf"])

if uploaded_files:
    for file in uploaded_files:
        # Save the upload to disk first; extraction works on file paths
        file_path = os.path.join(PDF_FOLDER, file.name)
        with open(file_path, "wb") as f:
            f.write(file.getbuffer())

        # Extract text from the uploaded PDF
        # NOTE(review): a new extractor is created per file; one instance
        # outside the loop would do the same job.
        extractor = PDFTextExtractor(PDF_FOLDER)
        document = extractor.extract_text_from_pdf(file_path)


        # Chunk the extracted text
        chunker = SimpleTextChunker(chunk_size=500, chunk_overlap=100)
        chunks = chunker.process_document(document)

        # Generate embeddings and upsert into Pinecone
        # NOTE(review): Streamlit reruns this script on every interaction, so
        # still-selected files may be re-embedded and re-upserted; consider
        # tracking processed filenames in session_state to avoid that.
        embeddings = [get_embedding(chunk["content"]) for chunk in chunks]
        st.session_state["store"].add(embeddings, chunks)

        st.success(f" '{file.name}' has been successfully added to the knowledge base!")

# ask question
question = st.text_input("Enter your question")

if st.button("Submit"):
    if not question.strip():
        st.warning(" Please enter a valid question.")
    else:
        # Generate query embedding
        query_embedding = get_embedding(question)

        # Perform similarity search
        relevant_chunks = st.session_state["store"].search(query_embedding)

        if not relevant_chunks:
            st.warning(" No relevant content found in the knowledge base. Please upload related documents first.")
        else:
            # Combine retrieved chunks into context
            # ("text" is the key produced by VectorStore.search)
            context = "\n".join([chunk["text"] for chunk in relevant_chunks])

            # Ask the LLM for the answer
            with st.spinner('AI is thinking...'):
                answer = ask_llm(question, context)

            st.markdown("### 🤖 AI Answer")
            st.write(answer)

            st.markdown("### 📖 Reference Chunks")
            st.write(context)
script/vector.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pinecone import Pinecone, ServerlessSpec
3
+ from dotenv import load_dotenv
4
+ import numpy as np
5
+
6
+ load_dotenv()
7
+
8
class VectorStore:
    """Thin wrapper around a Pinecone serverless index for chunk embeddings."""

    def __init__(self):
        # Credentials and index name come from the environment (.env).
        api_key = os.getenv("PINECONE_API_KEY")
        index_name = os.getenv("PINECONE_INDEX_NAME")

        # connect to Pinecone
        self.pc = Pinecone(api_key=api_key)
        if index_name not in self.pc.list_indexes().names():
            # dimension=1024 matches intfloat/multilingual-e5-large embeddings
            self.pc.create_index(
                name=index_name,
                dimension=1024,
                metric="cosine",
                spec=ServerlessSpec(
                    cloud='aws',
                    region='us-east-1'
                )
            )
            print(f" Created new Pinecone index: {index_name}")
        else:
            print(f"Reusing existing Pinecone index: {index_name}")

        self.index = self.pc.Index(index_name)

    def add(self, embeddings, chunks):
        """Upsert chunk embeddings with their text metadata.

        Vector ids are derived from (source, chunk_index) so chunks from
        different uploads cannot collide: the previous "chunk-{idx}" scheme
        restarted at 0 on every call and silently overwrote the vectors of
        earlier documents.
        """
        vectors = []
        for idx, emb in enumerate(embeddings):
            chunk = chunks[idx]
            vector_id = f"{chunk['source']}-{chunk['chunk_index']}"
            vectors.append((
                vector_id,
                emb,
                {"text": chunk["content"], "source": chunk["source"], "position": chunk["chunk_index"]}
            ))
        self.index.upsert(vectors)

    def search(self, query_embedding, top_k=5):
        """Return the top_k most similar chunks as plain dicts.

        Each dict carries text, source, position and the similarity score.
        """
        # (removed dead no-op `query_embedding = query_embedding`)
        results = self.index.query(vector=query_embedding, top_k=top_k, include_metadata=True)
        return [
            {
                "text": item["metadata"]["text"],
                "source": item["metadata"]["source"],
                "position": item["metadata"]["position"],
                "score": item["score"]
            }
            for item in results["matches"]
        ]