GitHub Actions committed on
Commit
c9622da
·
0 Parent(s):

Deploy from GitHub Actions

Browse files
.gitignore ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+ venv/
12
+ env/
13
+
14
+ # IDE
15
+ .idea/
16
+ .vscode/
17
+ *.swp
18
+ *.swo
19
+
20
+ # Project specific
21
+ chroma_db/
22
+ data/
23
+ *.gguf
24
+ *.bin
25
+
26
+ # Cache
27
+ .pytest_cache/
28
+ .mypy_cache/
29
+ .ruff_cache/
30
+
31
+ # OS
32
+ .DS_Store
33
+ Thumbs.db
34
+
35
+ # Large model files (should not be in git)
36
+ models/
37
+ *.safetensors
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11
README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: FreeRAG
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.0.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # FreeRAG - Local RAG System
14
+
15
+ A modular Retrieval Augmented Generation (RAG) system powered by Phi-3.5-mini.
16
+
17
+ ## Features
18
+
19
+ - 📄 **Multi-format support**: PDF, DOCX, TXT, Markdown
20
+ - 🔍 **Semantic search**: ChromaDB vector store with sentence-transformers
21
+ - 🤖 **Local LLM**: Phi-3.5-mini running via llama-cpp
22
+ - 💬 **Interactive chat**: Ask questions about your documents
23
+ - 🎨 **Modern UI**: Clean Gradio interface
24
+
25
+ ## Usage
26
+
27
+ 1. Upload your documents using the file upload panel
28
+ 2. Wait for processing to complete
29
+ 3. Ask questions in the chat interface
30
+ 4. Get AI-powered answers with source citations
31
+
32
+ ## Tech Stack
33
+
34
+ - **LLM**: Phi-3.5-mini (GGUF via llama-cpp-python)
35
+ - **Embeddings**: sentence-transformers (all-MiniLM-L6-v2)
36
+ - **Vector Store**: ChromaDB
37
+ - **UI**: Gradio
app.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio web interface for FreeRAG - designed for HuggingFace Spaces."""
2
+
3
+ import gradio as gr
4
+ from pathlib import Path
5
+ import tempfile
6
+ import os
7
+
8
+ from src.config import Config
9
+ from src.rag.pipeline import RAGPipeline
10
+
11
+
12
# Module-level singleton so the model, embeddings, and vector store are
# constructed once per process rather than once per request.
# Fix: annotation must admit None (the initial value); pyproject requires
# Python >= 3.11, so the `X | None` union syntax is available.
pipeline: RAGPipeline | None = None


def get_pipeline() -> RAGPipeline:
    """Return the process-wide RAG pipeline, creating it on first use.

    Lazy construction keeps app startup fast: the pipeline's heavy
    components are only built when the first request needs them.

    Returns:
        The shared RAGPipeline instance.
    """
    global pipeline
    if pipeline is None:
        pipeline = RAGPipeline(Config.default())
    return pipeline
22
+
23
+
24
def process_files(files):
    """Process uploaded files and add their chunks to the vector store.

    Robustness fix: a single unreadable file no longer aborts the whole
    batch — failures are collected and reported alongside successes.

    Args:
        files: Items from the Gradio upload component (objects with a
            ``.name`` path attribute, or plain path strings).

    Returns:
        Tuple of (status message, refreshed stats text).
    """
    if not files:
        return "Please upload at least one file.", get_stats_text()

    pipe = get_pipeline()
    total_chunks = 0
    processed_files = []
    failures = []

    for file in files:
        # Gradio may hand us a tempfile wrapper or a plain path string.
        file_path = file.name if hasattr(file, 'name') else file
        try:
            count = pipe.ingest_file(file_path)
        except Exception as e:
            failures.append(f"{Path(file_path).name}: {e}")
            continue
        total_chunks += count
        processed_files.append(Path(file_path).name)

    lines = []
    if processed_files:
        lines.append(f"✅ Successfully processed {len(processed_files)} file(s)!")
        lines.append(f"📄 Files: {', '.join(processed_files)}")
        lines.append(f"📊 Added {total_chunks} chunks to the knowledge base.")
    if failures:
        lines.append("⚠️ Error processing file(s): " + "; ".join(failures))

    return "\n".join(lines), get_stats_text()
49
+
50
+
51
def answer_question(question, top_k, chat_history):
    """Run a RAG query and append the (question, answer) pair to the chat.

    Args:
        question: User's question text.
        top_k: Number of chunks to retrieve (slider value, may be float).
        chat_history: Gradio chat history list of (user, assistant) tuples.

    Returns:
        Tuple of (updated chat history, "" to clear the input box).
    """
    if not question.strip():
        return chat_history, ""

    pipe = get_pipeline()

    if pipe.vector_store.get_count() == 0:
        reply = "⚠️ No documents have been uploaded yet. Please upload some documents first."
    else:
        try:
            result = pipe.query(question, top_k=int(top_k))
            reply = result["answer"]
            # Append a source-citation footer when retrieval found anything.
            if result["sources"]:
                names = [s["filename"] for s in result["sources"]]
                reply += f"\n\n---\n📚 *Sources: {', '.join(names)}*"
        except Exception as e:
            reply = f"❌ Error: {e}"

    chat_history.append((question, reply))
    return chat_history, ""
74
+
75
+
76
def get_stats_text() -> str:
    """Render knowledge-base statistics as a short multi-line string."""
    stats = get_pipeline().get_stats()
    parts = (
        f"📊 Documents: {stats['documents_count']} chunks",
        "🤖 Model: Phi-3.5-mini",
        f"📐 Embeddings: {stats['embedding_model']}",
    )
    return "\n".join(parts)
85
+
86
+
87
def clear_knowledge_base():
    """Remove every document from the vector store.

    Returns:
        Tuple of (status message, refreshed stats text) for the UI.
    """
    get_pipeline().vector_store.clear()
    return "🗑️ Knowledge base cleared.", get_stats_text()
92
+
93
+
94
# Custom CSS for modern dark theme
custom_css = """
.gradio-container {
    max-width: 1200px !important;
}
.chat-message {
    padding: 12px;
    border-radius: 8px;
    margin: 8px 0;
}
footer {
    display: none !important;
}
"""

# Build Gradio interface. Components are laid out top-to-bottom inside this
# `with` block; event wiring happens at the end, once all components exist.
with gr.Blocks(
    title="FreeRAG - Local RAG System",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="slate"
    ),
    css=custom_css
) as demo:

    # Page header.
    gr.Markdown("""
    # 🚀 FreeRAG
    ### Local RAG System powered by Phi-3.5-mini

    Upload your documents and ask questions! Everything runs locally with no data leaving your machine.
    """)

    with gr.Row():
        # Left column - Document Upload
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Upload Documents")

            # Limit uploads to the formats the document loader can parse.
            file_upload = gr.File(
                label="Upload files (PDF, DOCX, TXT, MD)",
                file_count="multiple",
                file_types=[".pdf", ".docx", ".txt", ".md"]
            )

            upload_btn = gr.Button("📤 Process Documents", variant="primary")
            upload_status = gr.Textbox(label="Status", lines=3, interactive=False)

            gr.Markdown("### 📊 Knowledge Base Stats")
            # NOTE(review): `value` is the get_stats_text function itself (not
            # its result) combined with `every=5` — presumably Gradio
            # re-invokes it on that interval; confirm against the installed
            # gradio version's Textbox documentation.
            stats_display = gr.Textbox(
                label="",
                value=get_stats_text,
                lines=3,
                interactive=False,
                every=5  # Refresh every 5 seconds
            )

            clear_btn = gr.Button("🗑️ Clear Knowledge Base", variant="secondary")

        # Right column - Chat Interface
        with gr.Column(scale=2):
            gr.Markdown("### 💬 Ask Questions")

            chatbot = gr.Chatbot(
                label="Conversation",
                height=400,
                show_copy_button=True
            )

            with gr.Row():
                question_input = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask anything about your documents...",
                    scale=4,
                    show_label=False
                )
                # Number of chunks retrieved per query (top_k for RAG search).
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Sources",
                    scale=1
                )

            with gr.Row():
                submit_btn = gr.Button("🔍 Ask", variant="primary", scale=2)
                clear_chat_btn = gr.Button("🧹 Clear Chat", scale=1)

    # Event handlers
    upload_btn.click(
        fn=process_files,
        inputs=[file_upload],
        outputs=[upload_status, stats_display]
    )

    # Both the Ask button and pressing Enter in the textbox submit a question;
    # answer_question returns (history, "") so the input box is cleared.
    submit_btn.click(
        fn=answer_question,
        inputs=[question_input, top_k_slider, chatbot],
        outputs=[chatbot, question_input]
    )

    question_input.submit(
        fn=answer_question,
        inputs=[question_input, top_k_slider, chatbot],
        outputs=[chatbot, question_input]
    )

    clear_btn.click(
        fn=clear_knowledge_base,
        outputs=[upload_status, stats_display]
    )

    # Clearing the chat only resets UI state: replace the history with [].
    clear_chat_btn.click(
        fn=lambda: [],
        outputs=[chatbot]
    )

    # Footer.
    gr.Markdown("""
    ---
    <center>
    <p style="color: gray;">
    Built with 💙 using Phi-3.5-mini, ChromaDB, and Gradio |
    <a href="https://github.com/yourusername/FreeRAG">GitHub</a>
    </p>
    </center>
    """)


if __name__ == "__main__":
    demo.launch()
main.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FreeRAG - A modular RAG system using Phi-3.5-mini.
2
+
3
+ CLI entrypoint for ingesting documents and querying the RAG system.
4
+ """
5
+
6
+ import typer
7
+ from rich.console import Console
8
+ from rich.panel import Panel
9
+ from rich.markdown import Markdown
10
+ from pathlib import Path
11
+
12
+ from src.config import Config
13
+ from src.rag.pipeline import RAGPipeline
14
+
15
+ app = typer.Typer(help="FreeRAG - Local RAG system with Phi-3.5-mini")
16
+ console = Console()
17
+
18
+
19
def get_pipeline() -> RAGPipeline:
    """Build a RAG pipeline from the default configuration."""
    config = Config.default()
    return RAGPipeline(config)
22
+
23
+
24
@app.command()
def ingest(
    path: str = typer.Argument(..., help="Path to file or directory to ingest"),
    recursive: bool = typer.Option(True, "--recursive/--no-recursive", "-r", help="Recursively search directories")
):
    """Ingest documents into the vector store."""
    pipeline = get_pipeline()
    target = Path(path)

    if not target.exists():
        console.print(f"[red]Error: Path not found: {path}[/red]")
        raise typer.Exit(1)

    with console.status("[bold green]Ingesting documents..."):
        # A file is ingested directly; a directory is walked.
        count = (
            pipeline.ingest_file(path)
            if target.is_file()
            else pipeline.ingest_directory(path, recursive=recursive)
        )

    console.print(Panel(f"[green]Successfully ingested {count} chunks![/green]"))
44
+
45
+
46
@app.command()
def query(
    question: str = typer.Argument(..., help="Question to ask"),
    top_k: int = typer.Option(3, "--top-k", "-k", help="Number of documents to retrieve")
):
    """Query the RAG system."""
    pipeline = get_pipeline()

    # Warn (but still run) when the store is empty — the LLM can answer
    # without context, just without citations.
    if pipeline.vector_store.get_count() == 0:
        console.print("[yellow]Warning: No documents in vector store. Run 'ingest' first.[/yellow]")

    with console.status("[bold green]Thinking..."):
        result = pipeline.query(question, top_k=top_k)

    # Display answer
    console.print(Panel(Markdown(result["answer"]), title="[bold blue]Answer[/bold blue]"))

    # Display sources
    sources = result["sources"]
    if sources:
        console.print("\n[dim]Sources:[/dim]")
        for entry in sources:
            console.print(f" • {entry['filename']}")
68
+
69
+
70
@app.command()
def chat():
    """Interactive chat mode."""
    pipeline = get_pipeline()

    console.print(Panel(
        "[bold]FreeRAG Chat Mode[/bold]\n"
        "Type your questions and press Enter.\n"
        "Type 'exit' or 'quit' to stop.",
        title="🤖 FreeRAG"
    ))

    console.print(f"[dim]Loaded {pipeline.vector_store.get_count()} document chunks.[/dim]\n")

    # Read-eval-print loop; Ctrl-C exits cleanly at any point.
    while True:
        try:
            user_text = console.input("[bold blue]You:[/bold blue] ")

            if user_text.lower() in ("exit", "quit", "q"):
                console.print("[dim]Goodbye![/dim]")
                break

            if not user_text.strip():
                continue

            with console.status("[bold green]Thinking..."):
                reply = pipeline.chat(user_text)

            console.print(f"[bold green]Assistant:[/bold green] {reply}\n")

        except KeyboardInterrupt:
            console.print("\n[dim]Goodbye![/dim]")
            break
104
+
105
+
106
@app.command()
def stats():
    """Show vector store statistics."""
    # Local renamed from `stats` to avoid shadowing this command's name.
    info = get_pipeline().get_stats()

    body = (
        f"📊 [bold]Documents:[/bold] {info['documents_count']} chunks\n"
        f"🗃️ [bold]Collection:[/bold] {info['collection_name']}\n"
        f"🤖 [bold]LLM:[/bold] {info['model']}\n"
        f"📐 [bold]Embeddings:[/bold] {info['embedding_model']}"
    )
    console.print(Panel(body, title="FreeRAG Statistics"))
119
+
120
+
121
@app.command()
def clear():
    """Clear the vector store."""
    # Guard clause: bail out unless the user explicitly confirms.
    if not typer.confirm("Are you sure you want to clear all documents?"):
        return
    get_pipeline().vector_store.clear()
    console.print("[green]Vector store cleared.[/green]")
128
+
129
+
130
+ if __name__ == "__main__":
131
+ app()
pyproject.toml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "freerag"
3
+ version = "0.1.0"
4
description = "Local RAG system powered by Phi-3.5-mini (ChromaDB + llama-cpp + Gradio)"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "chromadb>=0.4.22",
9
+ "gradio>=4.0.0",
10
+ "huggingface-hub>=0.20.0",
11
+ "ipykernel>=7.1.0",
12
+ "llama-cpp-python>=0.2.50",
13
+ "pypdf>=3.17.0",
14
+ "python-docx>=1.1.0",
15
+ "rich>=13.7.0",
16
+ "sentence-transformers>=2.2.2",
17
+ "typer>=0.9.0",
18
+ ]
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core Dependencies
2
+ huggingface_hub>=0.20.0
3
+ llama-cpp-python>=0.2.50
4
+
5
+ # Embeddings
6
+ sentence-transformers>=2.2.2
7
+
8
+ # Vector Store
9
+ chromadb>=0.4.22
10
+
11
+ # Document Loaders
12
+ pypdf>=3.17.0
13
+ python-docx>=1.1.0
14
+
15
+ # CLI & Utils
16
+ rich>=13.7.0
17
+ typer>=0.9.0
18
+
19
+ # Web UI (for HuggingFace Spaces)
20
+ gradio>=4.0.0
21
+
22
+ # Dev
23
+ ipykernel
src/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """FreeRAG - A modular RAG system using Phi-3.5-mini."""
2
+
3
+ from src.config import Config
4
+
5
+ __version__ = "0.1.0"
6
+ __all__ = ["Config"]
src/config.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration settings for FreeRAG."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+
6
+
7
@dataclass
class ModelConfig:
    """LLM model configuration."""
    # HuggingFace repo hosting the GGUF weights (passed to hf_hub_download).
    repo_id: str = "bartowski/Phi-3.5-mini-instruct-GGUF"
    # Quantized (Q4_K_M) weights file within that repo.
    filename: str = "Phi-3.5-mini-instruct-Q4_K_M.gguf"
    # Context window size in tokens (passed to llama-cpp's Llama).
    n_ctx: int = 4096
    # Threads used by llama-cpp for inference.
    n_threads: int = 4
    # Default cap on tokens generated per completion.
    max_tokens: int = 512
    # Sampling temperature; higher values give more varied output.
    temperature: float = 0.7
    # Whether llama-cpp prints its loading/debug output.
    verbose: bool = False
17
+
18
+
19
@dataclass
class EmbeddingConfig:
    """Embedding model configuration."""
    # sentence-transformers model id used to embed documents and queries.
    model_name: str = "all-MiniLM-L6-v2"
    # Device passed to SentenceTransformer ("cpu", or e.g. "cuda").
    device: str = "cpu"
24
+
25
+
26
@dataclass
class VectorStoreConfig:
    """Vector store configuration."""
    # Name of the collection documents are stored under.
    collection_name: str = "freerag_documents"
    # On-disk persistence location (created by Config.ensure_directories()).
    persist_directory: str = "./chroma_db"
    # Default number of chunks retrieved per query.
    top_k: int = 3
32
+
33
+
34
@dataclass
class ChunkingConfig:
    """Text chunking configuration."""
    # Target chunk length in characters; chunks may end earlier when the
    # splitter finds a sentence/paragraph boundary inside the window.
    chunk_size: int = 500
    # Characters of context carried over between consecutive chunks.
    chunk_overlap: int = 50
39
+
40
+
41
@dataclass
class Config:
    """Top-level configuration aggregating all component configs."""
    model: ModelConfig = field(default_factory=ModelConfig)
    embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
    vectorstore: VectorStoreConfig = field(default_factory=VectorStoreConfig)
    chunking: ChunkingConfig = field(default_factory=ChunkingConfig)
    data_directory: str = "./data"

    @classmethod
    def default(cls) -> "Config":
        """Build a configuration with every field at its default."""
        return cls()

    def ensure_directories(self) -> None:
        """Create the data and vector-store directories if missing."""
        for directory in (self.data_directory, self.vectorstore.persist_directory):
            Path(directory).mkdir(parents=True, exist_ok=True)
src/document_loader/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Document loader module for FreeRAG."""
2
+
3
+ from src.document_loader.loader import DocumentLoader, Document
4
+ from src.document_loader.splitter import TextSplitter
5
+
6
+ __all__ = ["DocumentLoader", "Document", "TextSplitter"]
src/document_loader/loader.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Document loader for various file formats."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+ from typing import List, Optional, Dict, Any
6
+
7
+
8
@dataclass
class Document:
    """A loaded document: raw text plus provenance metadata."""
    content: str
    # Expected keys (set by DocumentLoader): source, filename, extension.
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def source(self) -> str:
        """Absolute path the document was loaded from, or "unknown"."""
        return self.metadata.get("source", "unknown")


class DocumentLoader:
    """Load documents from the supported file formats.

    Stateless: the previous unused ``_pdf_loader``/``_docx_loader``
    attributes were dead state and have been removed.
    """

    SUPPORTED_EXTENSIONS = {".txt", ".md", ".pdf", ".docx"}

    def load_file(self, file_path: str) -> Document:
        """Load a single file.

        Args:
            file_path: Path to the file.

        Returns:
            Loaded document with source/filename/extension metadata.

        Raises:
            ValueError: If file format is not supported.
            FileNotFoundError: If file doesn't exist.
        """
        path = Path(file_path)

        if not path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        extension = path.suffix.lower()

        if extension not in self.SUPPORTED_EXTENSIONS:
            raise ValueError(
                f"Unsupported file format: {extension}. "
                f"Supported: {self.SUPPORTED_EXTENSIONS}"
            )

        content = self._load_by_extension(path, extension)

        return Document(
            content=content,
            metadata={
                "source": str(path.absolute()),
                "filename": path.name,
                "extension": extension
            }
        )

    def load_directory(
        self,
        directory_path: str,
        recursive: bool = True
    ) -> List[Document]:
        """Load all supported files from a directory.

        Args:
            directory_path: Path to the directory.
            recursive: Whether to search recursively.

        Returns:
            List of loaded documents.

        Raises:
            FileNotFoundError: If the directory doesn't exist.
            ValueError: If the path is not a directory.
        """
        path = Path(directory_path)

        if not path.exists():
            raise FileNotFoundError(f"Directory not found: {directory_path}")

        if not path.is_dir():
            raise ValueError(f"Not a directory: {directory_path}")

        documents = []
        pattern = "**/*" if recursive else "*"

        for file_path in path.glob(pattern):
            if file_path.is_file() and file_path.suffix.lower() in self.SUPPORTED_EXTENSIONS:
                # Best-effort: skip files that fail to parse, keep the rest.
                try:
                    doc = self.load_file(str(file_path))
                    documents.append(doc)
                    print(f"Loaded: {file_path.name}")
                except Exception as e:
                    print(f"Warning: Failed to load {file_path.name}: {e}")

        return documents

    def _load_by_extension(self, path: Path, extension: str) -> str:
        """Dispatch to the format-specific loader for *extension*."""
        if extension in {".txt", ".md"}:
            return self._load_text(path)
        elif extension == ".pdf":
            return self._load_pdf(path)
        elif extension == ".docx":
            return self._load_docx(path)
        else:
            raise ValueError(f"Unknown extension: {extension}")

    def _load_text(self, path: Path) -> str:
        """Load plain text / markdown, assuming UTF-8."""
        return path.read_text(encoding="utf-8")

    def _load_pdf(self, path: Path) -> str:
        """Extract text from a PDF, one page at a time."""
        try:
            from pypdf import PdfReader
        except ImportError as e:
            # Chain the cause so the original import failure is visible.
            raise ImportError("pypdf is required for PDF files: pip install pypdf") from e

        reader = PdfReader(str(path))
        # Some pages yield no text (e.g. scanned images); skip those.
        text_parts = [t for t in (page.extract_text() for page in reader.pages) if t]
        return "\n\n".join(text_parts)

    def _load_docx(self, path: Path) -> str:
        """Extract non-empty paragraphs from a DOCX file."""
        try:
            from docx import Document as DocxDocument
        except ImportError as e:
            raise ImportError("python-docx is required for DOCX files: pip install python-docx") from e

        doc = DocxDocument(str(path))
        paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
        return "\n\n".join(paragraphs)
src/document_loader/splitter.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Text splitter for chunking documents."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional
5
+
6
+ from src.config import ChunkingConfig
7
+ from src.document_loader.loader import Document
8
+
9
+
10
@dataclass
class TextChunk:
    """A chunk of a document's text."""
    content: str       # stripped chunk text
    metadata: dict     # document metadata plus chunk_index/start_char/end_char
    chunk_index: int   # 0-based position of the chunk within its document


class TextSplitter:
    """Split text into overlapping chunks, preferring sentence boundaries."""

    # Boundary candidates, tried in order of preference.
    _SEPARATORS = ("\n\n", "\n", ". ", "! ", "? ")

    def __init__(self, config: Optional[ChunkingConfig] = None):
        """Initialize the text splitter.

        Args:
            config: Chunking configuration. Uses defaults if not provided.
        """
        self.config = config or ChunkingConfig()

    def split_text(self, text: str, metadata: Optional[dict] = None) -> List[TextChunk]:
        """Split text into chunks.

        Fixes over the previous version:
        - The progress guard was a hard-to-read conditional expression
          (`if A if chunks else 0:`) that compared against the *last
          chunk's* start; when no chunk had been emitted yet and
          chunk_overlap >= the window advance, the loop never made
          progress (infinite loop). The guard now simply requires the
          next start to move past the current one.

        Args:
            text: Text to split.
            metadata: Optional metadata to attach to chunks.

        Returns:
            List of text chunks (empty for blank input).
        """
        if not text.strip():
            return []

        metadata = metadata or {}
        chunks: List[TextChunk] = []
        text = text.replace("\r\n", "\n")

        start = 0
        chunk_index = 0

        while start < len(text):
            end = start + self.config.chunk_size

            if end < len(text):
                # Prefer to cut at a natural boundary inside the window;
                # fall back to a hard cut at chunk_size if none is found.
                for sep in self._SEPARATORS:
                    boundary = text.rfind(sep, start, end)
                    if boundary > start:
                        end = boundary + len(sep)
                        break
            else:
                end = len(text)

            chunk_text = text[start:end].strip()

            if chunk_text:
                chunks.append(TextChunk(
                    content=chunk_text,
                    metadata={
                        **metadata,
                        "chunk_index": chunk_index,
                        "start_char": start,
                        "end_char": end
                    },
                    chunk_index=chunk_index
                ))
                chunk_index += 1

            # Step forward, keeping chunk_overlap characters of context.
            next_start = end - self.config.chunk_overlap
            if next_start <= start:
                # No forward progress (overlap >= window advance):
                # jump past the window instead of looping forever.
                next_start = end
            start = next_start

        return chunks

    def split_documents(self, documents: List[Document]) -> List[TextChunk]:
        """Split multiple documents into chunks.

        Args:
            documents: List of documents to split.

        Returns:
            List of text chunks from all documents, in document order.
        """
        all_chunks: List[TextChunk] = []
        for doc in documents:
            all_chunks.extend(self.split_text(doc.content, doc.metadata))
        return all_chunks
src/embeddings/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Embeddings module for FreeRAG."""
2
+
3
+ from src.embeddings.sentence_embeddings import EmbeddingModel
4
+
5
+ __all__ = ["EmbeddingModel"]
src/embeddings/sentence_embeddings.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sentence embeddings using sentence-transformers."""
2
+
3
+ import os
4
+ # Disable TensorFlow to avoid import conflicts with transformers
5
+ os.environ.setdefault("USE_TF", "0")
6
+ os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
7
+
8
+ from typing import List, Optional
9
+ from sentence_transformers import SentenceTransformer
10
+ import numpy as np
11
+
12
+ from src.config import EmbeddingConfig
13
+
14
+
15
class EmbeddingModel:
    """Embedding model wrapper using sentence-transformers."""

    def __init__(self, config: Optional[EmbeddingConfig] = None):
        """Initialize the embedding model.

        Args:
            config: Embedding configuration. Uses defaults if not provided.
        """
        self.config = config or EmbeddingConfig()
        self._model: Optional[SentenceTransformer] = None

    @property
    def model(self) -> SentenceTransformer:
        """Load the sentence-transformers model on first access."""
        if self._model is None:
            print(f"Loading embedding model: {self.config.model_name}...")
            self._model = SentenceTransformer(
                self.config.model_name,
                device=self.config.device
            )
            print("Embedding model loaded!")
        return self._model

    @property
    def dimension(self) -> int:
        """Dimensionality of the embedding vectors."""
        return self.model.get_sentence_embedding_dimension()

    def embed_text(self, text: str) -> List[float]:
        """Embed a single text.

        Args:
            text: Text to embed.

        Returns:
            Embedding vector as a list of floats.
        """
        return self.model.encode(text, convert_to_numpy=True).tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts.

        Args:
            texts: List of texts to embed.

        Returns:
            One embedding vector per input text.
        """
        return self.model.encode(texts, convert_to_numpy=True).tolist()

    def __call__(self, texts: List[str]) -> List[List[float]]:
        """Alias for embed_documents, matching ChromaDB's callable interface."""
        return self.embed_documents(texts)
src/llm/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """LLM module for FreeRAG."""
2
+
3
+ from src.llm.phi_model import PhiModel
4
+
5
+ __all__ = ["PhiModel"]
src/llm/phi_model.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Phi-3.5-mini model wrapper using llama-cpp-python."""
2
+
3
+ from typing import Optional, List, Dict, Any
4
+ from huggingface_hub import hf_hub_download
5
+ from llama_cpp import Llama
6
+
7
+ from src.config import ModelConfig
8
+
9
+
10
class PhiModel:
    """Wrapper for Phi-3.5-mini model."""

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize the model wrapper.

        Args:
            config: Model configuration. Uses defaults if not provided.
        """
        self.config = config or ModelConfig()
        self._model: Optional[Llama] = None
        self._model_path: Optional[str] = None

    @property
    def model(self) -> Llama:
        """Download and load the model on first access, then cache it."""
        if self._model is None:
            self._load_model()
        return self._model

    def _load_model(self) -> None:
        """Fetch the GGUF weights from HuggingFace and initialize llama-cpp."""
        print(f"Downloading model from {self.config.repo_id}...")
        self._model_path = hf_hub_download(
            repo_id=self.config.repo_id,
            filename=self.config.filename
        )

        print("Loading model into memory...")
        self._model = Llama(
            model_path=self._model_path,
            n_ctx=self.config.n_ctx,
            n_threads=self.config.n_threads,
            verbose=self.config.verbose
        )
        print("Model loaded successfully!")

    def generate(self, prompt: str, max_tokens: Optional[int] = None) -> str:
        """Generate a plain text completion.

        Args:
            prompt: Input prompt.
            max_tokens: Maximum tokens to generate (config default if falsy).

        Returns:
            Generated text, stripped.
        """
        result = self.model(
            prompt,
            max_tokens=max_tokens or self.config.max_tokens,
            temperature=self.config.temperature,
            echo=False
        )
        return result["choices"][0]["text"].strip()

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None
    ) -> str:
        """Generate a chat completion.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            max_tokens: Maximum tokens to generate (config default if falsy).

        Returns:
            The assistant's reply, stripped.
        """
        result = self.model.create_chat_completion(
            messages=messages,
            max_tokens=max_tokens or self.config.max_tokens,
            temperature=self.config.temperature
        )
        return result["choices"][0]["message"]["content"].strip()

    def chat_with_context(
        self,
        query: str,
        context: str,
        system_prompt: Optional[str] = None
    ) -> str:
        """Answer *query* grounded in retrieved *context* (RAG style).

        Args:
            query: User's question.
            context: Retrieved context from documents.
            system_prompt: Optional system prompt override.

        Returns:
            Generated response.
        """
        prompt = system_prompt
        if prompt is None:
            prompt = (
                "You are a helpful assistant. Answer the user's question based on "
                "the provided context. If the context doesn't contain relevant "
                "information, say so honestly. Be concise and accurate."
            )

        user_message = f"""Context:
{context}

Question: {query}

Please answer based on the context provided above."""

        return self.chat([
            {"role": "system", "content": prompt},
            {"role": "user", "content": user_message}
        ])
src/rag/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """RAG pipeline module for FreeRAG."""
2
+
3
+ from src.rag.retriever import Retriever
4
+ from src.rag.pipeline import RAGPipeline
5
+
6
+ __all__ = ["Retriever", "RAGPipeline"]
src/rag/pipeline.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main RAG pipeline orchestrating all components."""
2
+
3
+ from typing import Optional, Dict, Any
4
+
5
+ from src.config import Config
6
+ from src.llm.phi_model import PhiModel
7
+ from src.embeddings.sentence_embeddings import EmbeddingModel
8
+ from src.document_loader.loader import DocumentLoader
9
+ from src.document_loader.splitter import TextSplitter
10
+ from src.vectorstore.chroma_store import VectorStore
11
+ from src.rag.retriever import Retriever
12
+
13
+
14
class RAGPipeline:
    """Main RAG pipeline combining ingestion, retrieval, and generation.

    Components (LLM, embeddings, vector store, loader, splitter) are created
    lazily on first access so constructing the pipeline stays cheap and model
    weights are only loaded when actually needed.
    """

    def __init__(self, config: Optional["Config"] = None):
        """Initialize the RAG pipeline.

        Args:
            config: Configuration. Uses defaults if not provided.
        """
        self.config = config or Config.default()
        self.config.ensure_directories()

        # Backing fields for the lazy component properties below.
        self._llm: Optional["PhiModel"] = None
        self._embedding_model: Optional["EmbeddingModel"] = None
        self._vector_store: Optional["VectorStore"] = None
        self._retriever: Optional["Retriever"] = None
        self._document_loader: Optional["DocumentLoader"] = None
        self._text_splitter: Optional["TextSplitter"] = None

    @property
    def llm(self) -> "PhiModel":
        """Get the LLM instance, creating it on first access."""
        if self._llm is None:
            self._llm = PhiModel(self.config.model)
        return self._llm

    @property
    def embedding_model(self) -> "EmbeddingModel":
        """Get the embedding model instance, creating it on first access."""
        if self._embedding_model is None:
            self._embedding_model = EmbeddingModel(self.config.embedding)
        return self._embedding_model

    @property
    def vector_store(self) -> "VectorStore":
        """Get the vector store instance, creating it on first access."""
        if self._vector_store is None:
            self._vector_store = VectorStore(
                self.config.vectorstore,
                self.embedding_model
            )
        return self._vector_store

    @property
    def retriever(self) -> "Retriever":
        """Get the retriever instance, creating it on first access."""
        if self._retriever is None:
            self._retriever = Retriever(
                self.vector_store,
                top_k=self.config.vectorstore.top_k
            )
        return self._retriever

    @property
    def document_loader(self) -> "DocumentLoader":
        """Get the document loader instance, creating it on first access."""
        if self._document_loader is None:
            self._document_loader = DocumentLoader()
        return self._document_loader

    @property
    def text_splitter(self) -> "TextSplitter":
        """Get the text splitter instance, creating it on first access."""
        if self._text_splitter is None:
            self._text_splitter = TextSplitter(self.config.chunking)
        return self._text_splitter

    def ingest_file(self, file_path: str) -> int:
        """Ingest a single file into the vector store.

        Args:
            file_path: Path to the file.

        Returns:
            Number of chunks added.
        """
        print(f"Loading file: {file_path}")
        document = self.document_loader.load_file(file_path)

        print("Splitting into chunks...")
        chunks = self.text_splitter.split_text(document.content, document.metadata)

        print(f"Adding {len(chunks)} chunks to vector store...")
        return self.vector_store.add_chunks(chunks)

    def ingest_directory(self, directory_path: str, recursive: bool = True) -> int:
        """Ingest all files from a directory.

        Args:
            directory_path: Path to the directory.
            recursive: Whether to search recursively.

        Returns:
            Total number of chunks added.
        """
        print(f"Loading documents from: {directory_path}")
        documents = self.document_loader.load_directory(directory_path, recursive)

        if not documents:
            print("No documents found.")
            return 0

        print(f"Loaded {len(documents)} documents. Splitting into chunks...")
        chunks = self.text_splitter.split_documents(documents)

        print(f"Adding {len(chunks)} chunks to vector store...")
        return self.vector_store.add_chunks(chunks)

    def query(self, question: str, top_k: Optional[int] = None) -> Dict[str, Any]:
        """Query the RAG system.

        Args:
            question: User's question.
            top_k: Number of documents to retrieve.

        Returns:
            Dict with the question, the generated answer, the context string
            fed to the LLM, and a list of source descriptors.
        """
        # Retrieve exactly once. The previous implementation called the
        # retriever twice (retrieve_text + retrieve), which embedded the query
        # and searched the vector store twice for the same results.
        sources = self.retriever.retrieve(question, top_k)

        # Format the context string identically to Retriever.retrieve_text.
        if sources:
            context = "\n\n---\n\n".join(
                f"[Source {i}: {s['metadata'].get('filename', 'Unknown')}]\n{s['content']}"
                for i, s in enumerate(sources, 1)
            )
        else:
            context = "No relevant documents found."

        # Generate answer using the LLM grounded in the retrieved context.
        answer = self.llm.chat_with_context(question, context)

        return {
            "question": question,
            "answer": answer,
            "context": context,
            "sources": [
                {
                    "filename": s["metadata"].get("filename", "Unknown"),
                    "source": s["metadata"].get("source", "Unknown"),
                    "distance": s.get("distance")
                }
                for s in sources
            ]
        }

    def chat(self, question: str) -> str:
        """Simple chat interface that returns just the answer.

        Args:
            question: User's question.

        Returns:
            Answer string.
        """
        result = self.query(question)
        return result["answer"]

    def get_stats(self) -> Dict[str, Any]:
        """Get pipeline statistics.

        Returns:
            Dict with stats about the pipeline.
        """
        return {
            "documents_count": self.vector_store.get_count(),
            "collection_name": self.config.vectorstore.collection_name,
            "model": self.config.model.repo_id,
            "embedding_model": self.config.embedding.model_name
        }
src/rag/retriever.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Document retriever for RAG pipeline."""
2
+
3
+ from typing import List, Dict, Any, Optional
4
+
5
+ from src.vectorstore.chroma_store import VectorStore
6
+
7
+
8
class Retriever:
    """Fetch the document chunks most relevant to a user query."""

    def __init__(self, vector_store: "VectorStore", top_k: int = 3):
        """Initialize the retriever.

        Args:
            vector_store: Vector store to search.
            top_k: Number of documents to retrieve.
        """
        self.vector_store = vector_store
        self.top_k = top_k

    def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """Retrieve relevant documents for a query.

        Args:
            query: User query.
            top_k: Override default number of results.

        Returns:
            List of relevant documents with metadata.
        """
        # Falsy top_k (None/0) falls back to the instance default.
        limit = top_k or self.top_k
        return self.vector_store.search(query, top_k=limit)

    def retrieve_text(self, query: str, top_k: Optional[int] = None) -> str:
        """Retrieve and format documents as a single context string.

        Args:
            query: User query.
            top_k: Override default number of results.

        Returns:
            Formatted context string.
        """
        hits = self.retrieve(query, top_k)

        if not hits:
            return "No relevant documents found."

        sections = [
            f"[Source {idx}: {hit['metadata'].get('filename', 'Unknown')}]\n{hit['content']}"
            for idx, hit in enumerate(hits, start=1)
        ]
        return "\n\n---\n\n".join(sections)
src/vectorstore/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Vector store module for FreeRAG."""
2
+
3
+ from src.vectorstore.chroma_store import VectorStore
4
+
5
+ __all__ = ["VectorStore"]
src/vectorstore/chroma_store.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ChromaDB vector store implementation."""
2
+
3
+ from typing import List, Optional, Dict, Any
4
+ from pathlib import Path
5
+ import chromadb
6
+ from chromadb.config import Settings
7
+
8
+ from src.config import VectorStoreConfig
9
+ from src.embeddings.sentence_embeddings import EmbeddingModel
10
+ from src.document_loader.splitter import TextChunk
11
+
12
+
13
class VectorStore:
    """ChromaDB-based vector store for document storage and retrieval."""

    def __init__(
        self,
        config: Optional["VectorStoreConfig"] = None,
        embedding_model: Optional["EmbeddingModel"] = None
    ):
        """Initialize the vector store.

        Args:
            config: Vector store configuration.
            embedding_model: Embedding model for generating vectors.
        """
        self.config = config or VectorStoreConfig()
        self.embedding_model = embedding_model or EmbeddingModel()
        # Client and collection are created lazily on first use.
        self._client: Optional["chromadb.Client"] = None
        self._collection = None

    @property
    def client(self) -> "chromadb.Client":
        """Get or create the persistent ChromaDB client."""
        if self._client is None:
            persist_path = Path(self.config.persist_directory)
            persist_path.mkdir(parents=True, exist_ok=True)

            self._client = chromadb.PersistentClient(
                path=str(persist_path),
                settings=Settings(anonymized_telemetry=False)
            )
        return self._client

    @property
    def collection(self):
        """Get or create the collection (cosine distance space)."""
        if self._collection is None:
            self._collection = self.client.get_or_create_collection(
                name=self.config.collection_name,
                metadata={"hnsw:space": "cosine"}
            )
        return self._collection

    def add_chunks(self, chunks: List["TextChunk"]) -> int:
        """Add text chunks to the vector store.

        Args:
            chunks: List of text chunks to add.

        Returns:
            Number of chunks added.
        """
        # Local import keeps this stdlib dependency next to its single use.
        import uuid

        if not chunks:
            return 0

        # Prepare data for ChromaDB
        documents = [chunk.content for chunk in chunks]
        metadatas = [chunk.metadata for chunk in chunks]

        # Use random UUIDs instead of count-based IDs ("doc_{count+i}"):
        # count-derived IDs collide with existing entries once any document
        # has been deleted (or a prior run left overlapping IDs), and Chroma
        # then rejects or deduplicates the add, silently losing chunks.
        ids = [f"doc_{uuid.uuid4().hex}" for _ in chunks]

        # Generate embeddings
        print(f"Generating embeddings for {len(chunks)} chunks...")
        embeddings = self.embedding_model.embed_documents(documents)

        # Add to collection
        self.collection.add(
            ids=ids,
            documents=documents,
            metadatas=metadatas,
            embeddings=embeddings
        )

        print(f"Added {len(chunks)} chunks to vector store.")
        return len(chunks)

    def search(
        self,
        query: str,
        top_k: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """Search for similar documents.

        Args:
            query: Search query.
            top_k: Number of results to return. Falls back to the configured
                default when omitted (or falsy).

        Returns:
            List of results with document content, metadata, and distance.
        """
        top_k = top_k or self.config.top_k

        # Generate query embedding
        query_embedding = self.embedding_model.embed_text(query)

        # Search
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "metadatas", "distances"]
        )

        # Chroma returns one nested list per query; we issue a single query,
        # so index [0] everywhere below.
        formatted = []
        if results["documents"]:
            for i, doc in enumerate(results["documents"][0]):
                formatted.append({
                    "content": doc,
                    "metadata": results["metadatas"][0][i] if results["metadatas"] else {},
                    "distance": results["distances"][0][i] if results["distances"] else None
                })

        return formatted

    def get_count(self) -> int:
        """Get the number of documents in the store."""
        return self.collection.count()

    def clear(self) -> None:
        """Clear all documents by dropping and forgetting the collection."""
        self.client.delete_collection(self.config.collection_name)
        # Drop the cached handle so the next access recreates an empty one.
        self._collection = None
        print("Vector store cleared.")
uv.lock ADDED
The diff for this file is too large to render. See raw diff