| import os |
| import secrets |
| from datetime import datetime, timedelta |
| from typing import List, Dict, Optional, Union |
|
|
| import chromadb |
| import json |
| from fastapi import FastAPI, HTTPException, Body, Query, Depends |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
| from fastapi.middleware.cors import CORSMiddleware |
| from jose import jwt, JWTError |
|
|
| from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_community.embeddings import HuggingFaceEmbeddings |
| from chromadb.config import Settings |
|
|
| |
class TokenManager:
    """Handle JWT token generation and validation.

    NOTE(review): this class is re-declared verbatim later in this file; the
    later definition shadows this one at import time, so this copy is dead
    code. One of the duplicates should be removed.
    """

    # Random per-process signing key: all issued tokens are invalidated on
    # every restart, and each duplicate class definition draws its own key.
    SECRET_KEY = secrets.token_hex(32)
    ALGORITHM = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES = 30


    @staticmethod
    def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create JWT access token.

        Args:
            data: Claims to embed in the token (copied, not mutated).
            expires_delta: Optional token lifetime; defaults to 15 minutes,
                even though ACCESS_TOKEN_EXPIRE_MINUTES is 30 — callers must
                pass it explicitly to get the 30-minute lifetime.

        Returns:
            str: Signed JWT with an ``exp`` claim added.
        """
        to_encode = data.copy()
        # Naive UTC timestamp; python-jose converts it to a numeric "exp" claim.
        expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
        to_encode.update({"exp": expire})
        return jwt.encode(to_encode, TokenManager.SECRET_KEY, algorithm=TokenManager.ALGORITHM)


    @staticmethod
    def decode_token(token: str) -> Dict:
        """Decode and validate JWT token.

        Returns:
            Dict: The decoded claim set.

        Raises:
            HTTPException: 401 when the signature or expiry check fails.
        """
        try:
            return jwt.decode(token, TokenManager.SECRET_KEY, algorithms=[TokenManager.ALGORITHM])
        except JWTError:
            raise HTTPException(status_code=401, detail="Invalid token")
|
|
| import os |
| import re |
| import secrets |
| from datetime import datetime, timedelta |
| from typing import List, Dict, Optional, Union |
|
|
| import chromadb |
| import json |
| import requests |
| from bs4 import BeautifulSoup |
| from urllib.parse import urljoin, urlparse |
|
|
| import chromadb |
| import json |
| from fastapi import FastAPI, HTTPException, Body, Query, Depends |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
| from fastapi.middleware.cors import CORSMiddleware |
| from jose import jwt, JWTError |
|
|
| from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_community.embeddings import HuggingFaceEmbeddings |
| from chromadb.config import Settings |
|
|
| import os |
| import re |
| import time |
| import secrets |
| from datetime import datetime, timedelta |
| from typing import List, Dict, Optional, Union, Callable |
|
|
| import requests |
| from bs4 import BeautifulSoup |
| from urllib.parse import urljoin, urlparse |
| import logging |
|
|
class ContentExtractorBase:
    """Abstract base for content extraction strategies.

    Concrete strategies implement :meth:`extract` as a static method that
    turns parsed HTML into plain text.
    """

    @staticmethod
    def extract(soup: BeautifulSoup) -> str:
        """Extract text content from parsed HTML.

        Args:
            soup (BeautifulSoup): Parsed HTML content

        Returns:
            str: Extracted text content

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError("Subclasses must implement extract method")
|
|
class DefaultContentExtractor(ContentExtractorBase):
    """Default content extraction strategy for generic web pages."""

    @staticmethod
    def extract(soup: BeautifulSoup) -> str:
        """Extract readable text from a parsed page.

        Strategy:
        1. Probe common main-content containers, most specific first.
        2. Strip boilerplate elements from the matched container.
        3. Return normalized text if it looks substantial (> 100 chars),
           otherwise fall back to the whole document's text.
        """
        # Candidate containers, most specific first; 'body' is the catch-all.
        candidate_selectors = (
            'main', 'article', 'div.content', 'div.main-content',
            'section.content', 'body'
        )

        for css in candidate_selectors:
            node = soup.select_one(css)
            if not node:
                continue

            # Remove non-content markup before reading the text.
            for junk in node(['script', 'style', 'nav', 'footer', 'header', 'aside']):
                junk.decompose()

            # Collapse all runs of whitespace into single spaces.
            cleaned = re.sub(r'\s+', ' ', node.get_text(separator=' ', strip=True)).strip()

            # Only accept the container if it carries substantial text.
            if cleaned and len(cleaned) > 100:
                return cleaned

        # Nothing substantial matched: fall back to the full-document text.
        return soup.get_text(separator=' ', strip=True)
|
|
class BlogContentExtractor(ContentExtractorBase):
    """Content extractor specialized for blog-style layouts."""

    @staticmethod
    def extract(soup: BeautifulSoup) -> str:
        """Extract the article body from blog-style markup.

        Probes common blog content containers; if none yields substantial
        text (> 100 chars), delegates to DefaultContentExtractor.
        """
        # Selectors matching typical blog/CMS article bodies.
        article_selectors = (
            'article .entry-content',
            '.post-content',
            'div.blog-post',
            '.article-body'
        )

        for css in article_selectors:
            node = soup.select_one(css)
            if not node:
                continue

            # Drop scripts, styling and page chrome from the article body.
            for junk in node(['script', 'style', 'aside', 'footer', 'header']):
                junk.decompose()

            cleaned = node.get_text(separator=' ', strip=True)
            cleaned = re.sub(r'\s+', ' ', cleaned).strip()

            if cleaned and len(cleaned) > 100:
                return cleaned

        # No blog container matched: use the generic strategy.
        return DefaultContentExtractor.extract(soup)
|
|
class ContentExtractorFactory:
    """Factory for selecting a content extraction strategy.

    Extractors are stateless classes whose ``extract`` is a staticmethod,
    so the factory hands out the class itself rather than an instance.
    """

    # Registry of known extraction strategies, keyed by site type.
    EXTRACTORS = {
        'default': DefaultContentExtractor,
        'blog': BlogContentExtractor
    }

    @classmethod
    def get_extractor(cls, site_type: str = 'default') -> "type[ContentExtractorBase]":
        """
        Get the appropriate content extractor for a site type.

        Fix: the previous return annotation claimed an *instance* of
        ContentExtractorBase, but the method returns the extractor class.

        Args:
            site_type (str): Type of website content ('default' or 'blog').

        Returns:
            type[ContentExtractorBase]: Extractor class; unknown site types
            fall back to DefaultContentExtractor.
        """
        return cls.EXTRACTORS.get(site_type, DefaultContentExtractor)
|
|
class RateLimiter:
    """
    Configurable rate limiting for outbound requests.

    Provides three cooperating controls:
    - a hard cap on the total number of requests,
    - a minimum spacing between consecutive requests,
    - exponential backoff timings for retrying failed requests.
    """

    def __init__(
        self,
        max_requests: int = 50,
        request_interval: float = 1.0,
        max_retries: int = 3
    ):
        """
        Initialize the rate limiter.

        Args:
            max_requests (int): Maximum number of requests allowed in total
            request_interval (float): Minimum time between requests in seconds
            max_retries (int): Maximum retry attempts for failed requests
        """
        self.max_requests = max_requests
        self.request_interval = request_interval
        self.max_retries = max_retries

        # Mutable state: how many requests succeeded, and when the last one ran.
        self.request_count = 0
        self.last_request_time = 0

    def wait(self):
        """Sleep just long enough to honour the configured request spacing."""
        elapsed = time.time() - self.last_request_time
        remaining = self.request_interval - elapsed
        if remaining > 0:
            time.sleep(remaining)

    def can_request(self) -> bool:
        """
        Report whether the total request budget still allows another request.

        Returns:
            bool: True while fewer than ``max_requests`` have been recorded.
        """
        return self.request_count < self.max_requests

    def record_request(self):
        """Count one successful request and stamp its completion time."""
        self.request_count += 1
        self.last_request_time = time.time()

    def exponential_backoff(self, attempt: int) -> float:
        """
        Compute the backoff delay for a retry attempt.

        Args:
            attempt (int): Zero-based retry attempt number.

        Returns:
            float: 2**attempt seconds, capped at 60.
        """
        return min(2 ** attempt, 60)
|
|
class DomainCrawler:
    """
    Advanced domain crawler with:
    - Rate limiting
    - Configurable content extraction
    - Robust URL handling

    Crawls recursively from ``base_url``, staying on the same host, and
    accumulates extracted page text in ``self.page_sources``.
    """

    def __init__(
        self,
        base_url: str,
        max_pages: int = 50,
        depth: int = 3,
        site_type: str = 'default',
        rate_limit_config: Dict = None
    ):
        """
        Initialize domain crawler with advanced configurations.

        Args:
            base_url (str): Starting URL to crawl
            max_pages (int): Maximum number of pages to index
            depth (int): Maximum depth of links to follow
            site_type (str): Type of website for content extraction
            rate_limit_config (Dict): Custom rate limiting configuration
        """

        self.base_url = self.normalize_url(base_url)
        self.max_pages = max_pages
        self.depth = depth

        # Extraction strategy is a class with a static `extract(soup)` method.
        self.content_extractor = ContentExtractorFactory.get_extractor(site_type)

        # Default budget: one counted request per target page, 1s spacing,
        # up to 3 retries per fetch.
        rate_config = rate_limit_config or {
            'max_requests': max_pages,
            'request_interval': 1.0,
            'max_retries': 3
        }
        self.rate_limiter = RateLimiter(**rate_config)

        # Crawl state: canonical URLs already visited and collected payloads.
        self.visited_urls = set()
        self.page_sources = []

        # NOTE(review): basicConfig is a process-global side effect; it is a
        # no-op if logging was already configured elsewhere in the process.
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)


    @staticmethod
    def normalize_url(url: str) -> str:
        """Normalize URLs to prevent duplicate indexing.

        Canonicalization: drop the fragment, rewrite "http(s)://www." to
        "https://" (this both strips "www." and forces https, so scheme
        variants collapse to one key), strip a trailing slash, lowercase.
        NOTE(review): lowercasing also lowercases the path/query, which may
        break servers with case-sensitive paths — confirm acceptable.
        """
        try:
            parsed = urlparse(url)
            clean_url = parsed._replace(fragment='')
            normalized = clean_url.geturl()
            normalized = re.sub(r'^https?://www\.', 'https://', normalized)
            normalized = re.sub(r'/$', '', normalized)
            return normalized.lower()
        except Exception:
            # Best effort: fall back to a lowercased raw URL.
            return url.lower()


    def is_valid_url(self, url: str) -> bool:
        """Check if URL is valid and within the same domain.

        Also rejects obvious binary assets by extension and URLs already
        visited — so this doubles as the crawl's deduplication check.
        """
        try:
            normalized_url = self.normalize_url(url)
            parsed_base = urlparse(self.base_url)
            parsed_url = urlparse(normalized_url)

            return (
                parsed_base.netloc == parsed_url.netloc and
                parsed_url.scheme in ['http', 'https'] and
                not re.search(r'\.(pdf|jpg|jpeg|png|gif|mp4|mp3|zip|rar)$', parsed_url.path, re.IGNORECASE) and
                normalized_url not in self.visited_urls
            )
        except Exception:
            return False


    def extract_text(self, url: str) -> str:
        """
        Extract text content with retry and rate limiting.

        Only successful fetches count against the rate-limit budget; each
        failed attempt sleeps an exponential backoff before retrying.

        Args:
            url (str): URL to extract content from

        Returns:
            str: Extracted text content ("" on budget exhaustion or failure)
        """
        if not self.rate_limiter.can_request():
            self.logger.warning(f"Rate limit exceeded. Stopping crawl at {url}")
            return ""


        for attempt in range(self.rate_limiter.max_retries):
            try:
                # Honour the minimum spacing between requests.
                self.rate_limiter.wait()

                # Browser-like UA; some sites reject the default requests UA.
                response = requests.get(url, timeout=10, headers={
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                })
                response.raise_for_status()


                soup = BeautifulSoup(response.text, 'html.parser')

                # Delegate to the configured extraction strategy.
                text = self.content_extractor.extract(soup)

                # Count the request only after a successful extraction.
                self.rate_limiter.record_request()

                return text

            except requests.RequestException as e:
                # NOTE(review): only network errors are caught; a parser or
                # extractor exception would propagate out of this method.
                self.logger.error(f"Request error for {url} (Attempt {attempt+1}): {e}")

                backoff_time = self.rate_limiter.exponential_backoff(attempt)
                time.sleep(backoff_time)

        self.logger.error(f"Failed to extract content from {url} after {self.rate_limiter.max_retries} attempts")
        return ""


    def crawl(self, url: str = None, current_depth: int = 0) -> List[Dict]:
        """
        Recursive web crawler with advanced controls.

        Args:
            url (str, optional): URL to crawl. Defaults to base_url.
            current_depth (int, optional): Current crawl depth. Defaults to 0.

        Returns:
            List[Dict]: Collected page sources (also kept on self.page_sources)
        """
        url = self.normalize_url(url or self.base_url)

        # Stop on depth/page budget, or anything is_valid_url rejects
        # (off-domain, binary asset, or already visited).
        if (current_depth > self.depth or
            len(self.visited_urls) >= self.max_pages or
            not self.is_valid_url(url)):
            return self.page_sources


        # Redundant with is_valid_url's visited check, kept as a safety net.
        if url in self.visited_urls:
            return self.page_sources

        self.visited_urls.add(url)

        # Fetch + extract the page body (counts against the rate budget).
        text = self.extract_text(url)
        if text:
            self.page_sources.append({
                "text": text,
                "source": url,
                "indexed_at": datetime.utcnow().isoformat()
            })

        # Stop link discovery once the request budget is spent.
        if not self.rate_limiter.can_request():
            self.logger.info("Rate limit reached. Stopping crawl.")
            return self.page_sources


        # NOTE(review): this re-downloads the page that extract_text just
        # fetched (two HTTP requests per page), and this second fetch is not
        # recorded against the rate limiter nor checked for HTTP errors.
        try:
            response = requests.get(url, timeout=10, headers={
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            })
            soup = BeautifulSoup(response.text, 'html.parser')

            for link in soup.find_all('a', href=True):
                # Resolve relative links against the current page URL.
                absolute_link = urljoin(url, link['href'])
                normalized_link = self.normalize_url(absolute_link)

                if (self.is_valid_url(normalized_link) and
                    normalized_link not in self.visited_urls):
                    self.crawl(normalized_link, current_depth + 1)

        except Exception as e:
            # Broad catch: a single bad page must not abort the whole crawl.
            self.logger.error(f"Crawling error for {url}: {e}")


        return self.page_sources
class DocumentIndexerConfig:
    """Enhanced configuration management.

    Configuration lives in a JSON file; a missing or unreadable file falls
    back to DEFAULT_CONFIG so the server can always start.
    """

    CONFIG_FILE = "server_config.json"

    DEFAULT_CONFIG = {
        "chunk_size": 500,
        "chunk_overlap": 50,
        "embedding_model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        "top_k": 5,
        "persist_directory": "/home/user/.cache/chroma_db",
        "default_collection": "documents"
    }


    @classmethod
    def load_config(cls) -> Dict:
        """Load configuration with fallback to defaults.

        Returns:
            Dict: DEFAULT_CONFIG overlaid with any values from CONFIG_FILE.
            A missing, unreadable, or malformed file yields the defaults
            (previously a corrupt file crashed with JSONDecodeError).
        """
        try:
            with open(cls.CONFIG_FILE, "r") as f:
                config = json.load(f)
            # File values win over defaults; unknown keys are preserved.
            return {**cls.DEFAULT_CONFIG, **config}
        except (FileNotFoundError, json.JSONDecodeError, OSError):
            # Return a copy so callers mutating the result cannot corrupt
            # the shared class-level defaults (the old code aliased them).
            return dict(cls.DEFAULT_CONFIG)


    @classmethod
    def save_config(cls, config_data: Dict):
        """Save configuration, preserving existing data.

        Args:
            config_data (Dict): Keys to merge into the stored configuration.
        """
        existing_config = cls.load_config()
        existing_config.update(config_data)
        with open(cls.CONFIG_FILE, "w") as f:
            json.dump(existing_config, f, indent=4)
|
|
|
|
|
|
class UserManager:
    """Enhanced user management with role-based access.

    SECURITY NOTE(review): credentials are hard-coded in plain text — fine
    for a demo only. Production code should store salted password hashes
    (bcrypt/argon2) loaded from secure storage, not source code.
    """

    # In-memory user store: password, role, and accessible collections.
    USERS = {
        "admin": {
            "password": "admin123",
            "role": "admin",
            "collections": ["all"]
        },
        "customer": {
            "password": "customer123",
            "role": "user",
            "collections": ["customer_1"]
        }
    }

    @staticmethod
    def authenticate_user(username: str, password: str) -> Optional[Dict]:
        """Authenticate user credentials.

        Uses a constant-time comparison (secrets.compare_digest) so timing
        does not leak how much of the password prefix matched.

        Args:
            username (str): Login name to look up.
            password (str): Candidate password.

        Returns:
            Optional[Dict]: The user record on success, else None.
        """
        user = UserManager.USERS.get(username)
        if (
            user
            and isinstance(password, str)
            and secrets.compare_digest(
                user["password"].encode("utf-8"), password.encode("utf-8")
            )
        ):
            return user
        return None

    @staticmethod
    def get_user_collections(username: str) -> List[str]:
        """Get collections a user can access (empty list for unknown users)."""
        user = UserManager.USERS.get(username)
        return user.get('collections', []) if user else []
class DocumentIndexer:
    """Comprehensive document indexing with enhanced metadata and source tracking."""

    def __init__(self, config: Dict = None, username: str = None):
        """
        Initialize with configuration and username for collection management.

        Args:
            config (Dict, optional): Indexer configuration; defaults to
                DocumentIndexerConfig.load_config().
            username (str, optional): Collection owner. "admin" uses the
                globally configured default collection; anyone else gets an
                isolated "<username>_collection".
        """
        self.config = config or DocumentIndexerConfig.load_config()
        self.username = username

        # Admins share the configured default collection; all other users
        # (including username=None, yielding "None_collection") are isolated.
        if username == "admin":
            self.default_collection = self.config["default_collection"]
        else:
            self.default_collection = f"{username}_collection"

        self.chroma_client = chromadb.Client(
            Settings(persist_directory=self.config["persist_directory"])
        )
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=self.config["embedding_model"]
        )


    def split_document(self, text: str) -> List[str]:
        """Split text into overlapping chunks per the configured sizes.

        Args:
            text (str): Raw document text.

        Returns:
            List[str]: Chunked text (empty list for empty input).
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config.get("chunk_size", 500),
            chunk_overlap=self.config.get("chunk_overlap", 50)
        )
        return splitter.split_text(text)


    def index_document(
        self,
        text: str,
        doc_type: str,
        collection_name: str,
        source: str = None
    ) -> Dict:
        """
        Advanced document indexing with comprehensive metadata.

        Args:
            text (str): Document text content
            doc_type (str): Type of document ('manual', 'pdf', 'website', 'domain')
            collection_name (str): Target collection
            source (str, optional): Source URL or path

        Returns:
            Dict: Status payload with the generated doc_id and chunk count,
            or a warning payload when the text produced no chunks.
        """
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        chunks = self.split_document(text)

        # Robustness fix: empty/whitespace input yields no chunks; skip the
        # add() call instead of inserting an empty batch.
        if not chunks:
            return {
                "status": "warning",
                "message": "No text content to index",
                "collection": collection_name,
                "chunks": 0
            }

        # NOTE(review): id generation from the current chunk count is racy if
        # two indexing requests hit the same collection concurrently.
        doc_id = f"{collection_name}_{len(collection.get()['ids'])}"

        collection.add(
            documents=chunks,
            metadatas=[{
                "doc_id": doc_id,
                "source_type": doc_type or "",
                "source": source or "",
                "indexed_at": datetime.utcnow().isoformat(),
                "chunk_number": i,
                "total_chunks": len(chunks)
            } for i in range(len(chunks))],
            ids=[f"{doc_id}_{i}" for i in range(len(chunks))]
        )

        return {
            "status": "success",
            "doc_id": doc_id,
            "collection": collection_name,
            "chunks": len(chunks)
        }



    def index_source(self, source: str, doc_type: str, collection_name: str) -> Dict:
        """
        Comprehensive source indexing with multiple strategies.

        Supports:
        - "manual": `source` is the raw text itself
        - "pdf": `source` is a PDF path/URL loaded via PyPDFLoader
        - "website": `source` is a URL loaded via WebBaseLoader
        - "domain": `source` is a base URL whose whole domain is crawled

        Raises:
            HTTPException: 400 for an unsupported doc_type.
        """
        # Bug fix: the previous code called a nonexistent
        # PyPDFLoader.load_and_split_text() and passed lists of langchain
        # Document objects where index_document expects a str. Both loaders'
        # load() return List[Document]; join their page_content into one blob.
        strategies = {
            "manual": lambda: self.index_document(source, doc_type, collection_name),
            "pdf": lambda: self.index_document(
                "\n".join(doc.page_content for doc in PyPDFLoader(source).load()),
                doc_type,
                collection_name,
                source
            ),
            "website": lambda: self.index_document(
                "\n".join(doc.page_content for doc in WebBaseLoader(source).load()),
                doc_type,
                collection_name,
                source
            ),
            "domain": lambda: self._index_domain(source, collection_name)
        }

        strategy = strategies.get(doc_type)
        if not strategy:
            raise HTTPException(status_code=400, detail="Unsupported document type")

        return strategy()


    def _index_domain(self, base_url: str, collection_name: str) -> Dict:
        """
        Crawl an entire domain and index each page individually.

        Args:
            base_url (str): Starting URL for the crawl.
            collection_name (str): Target collection for all pages.

        Returns:
            Dict: Aggregate status with per-page indexing results.
        """
        crawler = DomainCrawler(base_url)
        page_sources = crawler.crawl()

        if not page_sources:
            return {
                "status": "warning",
                "message": "No text content found",
                "collection": collection_name,
                "chunks": 0
            }

        # Index every crawled page separately so each keeps its own source URL.
        results = []
        for page in page_sources:
            result = self.index_document(
                page["text"],
                "domain",
                collection_name,
                page["source"]
            )
            results.append(result)

        return {
            "status": "success",
            "collection": collection_name,
            "total_pages_indexed": len(results),
            "details": results
        }


    def search_documents(
        self,
        query: str,
        top_k: int = None,
        collection_name: str = None
    ) -> Dict:
        """Semantic search over an indexed collection.

        Args:
            query (str): Search query text.
            top_k (int, optional): Result count; defaults to config "top_k".
            collection_name (str, optional): Defaults to the *configured*
                default collection (not self.default_collection) — callers
                enforcing per-user isolation must pass the collection name.

        Returns:
            Dict: Raw ChromaDB query result.
        """
        collection_name = collection_name or self.config["default_collection"]
        collection = self.chroma_client.get_or_create_collection(name=collection_name)

        top_k = top_k or self.config.get("top_k", 5)
        results = collection.query(query_texts=[query], n_results=top_k)

        return results
|
|
class TokenManager:
    """Handle JWT token generation and validation.

    NOTE(review): this duplicates an identical TokenManager defined earlier
    in the file; this later definition is the one actually in effect.
    """

    # Fix: a purely random per-process key invalidates every issued token on
    # restart and breaks multi-worker deployments (each worker signs with a
    # different key). Prefer an operator-supplied key via the JWT_SECRET_KEY
    # environment variable; fall back to the old random-key behavior.
    SECRET_KEY = os.environ.get("JWT_SECRET_KEY") or secrets.token_hex(32)
    ALGORITHM = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES = 30


    @staticmethod
    def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT access token.

        Args:
            data: Claims to embed (copied, not mutated).
            expires_delta: Optional lifetime; defaults to 15 minutes when not
                given (callers pass ACCESS_TOKEN_EXPIRE_MINUTES explicitly).

        Returns:
            str: Signed JWT with an ``exp`` claim added.
        """
        to_encode = data.copy()
        # Naive UTC datetime; python-jose converts it to a numeric exp claim.
        expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
        to_encode.update({"exp": expire})
        return jwt.encode(to_encode, TokenManager.SECRET_KEY, algorithm=TokenManager.ALGORITHM)


    @staticmethod
    def decode_token(token: str) -> Dict:
        """Decode and validate a JWT.

        Returns:
            Dict: The decoded claim set.

        Raises:
            HTTPException: 401 when signature/expiry validation fails.
        """
        try:
            return jwt.decode(token, TokenManager.SECRET_KEY, algorithms=[TokenManager.ALGORITHM])
        except JWTError:
            raise HTTPException(status_code=401, detail="Invalid token")
|
|
|
|
def create_app():
    """Create and configure the FastAPI application.

    Wires authentication, configuration, indexing, search and admin
    endpoints around DocumentIndexer / TokenManager / UserManager.

    Returns:
        FastAPI: The configured application instance.
    """
    app = FastAPI(
        title="Document Indexing API",
        openapi_url="/api/v1/openapi.json",
        docs_url="/documentation"
    )

    # NOTE(review): a wildcard origin combined with allow_credentials=True is
    # rejected by browsers per the CORS spec; restrict origins in production.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    security = HTTPBearer()

    def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
        """Resolve the current username from a Bearer JWT (401 on failure)."""
        token = credentials.credentials
        payload = TokenManager.decode_token(token)
        username = payload.get("sub")
        if not username:
            raise HTTPException(status_code=401, detail="Invalid token")
        return username


    @app.post("/login")
    async def login(credentials: Dict[str, str] = Body(...)):
        """Authenticate and issue a JWT access token."""
        username = credentials.get("username")
        password = credentials.get("password")

        user = UserManager.authenticate_user(username, password)
        if not user:
            raise HTTPException(status_code=401, detail="Invalid credentials")

        access_token = TokenManager.create_access_token(
            data={"sub": username, "role": user["role"]},
            expires_delta=timedelta(minutes=TokenManager.ACCESS_TOKEN_EXPIRE_MINUTES)
        )

        return {
            "token": access_token,
            "username": username,
            "role": user["role"]
        }


    @app.get("/admin/config")
    def get_config(user: str = Depends(get_current_user)):
        """Get server configuration (Admin only).

        Bug fix: the previous code returned the undefined name
        `server_config`, raising NameError at request time.
        """
        if user != "admin":
            raise HTTPException(status_code=403, detail="Access denied")
        return DocumentIndexerConfig.load_config()

    @app.post("/admin/config")
    def update_config(new_config: Dict, user: str = Depends(get_current_user)):
        """Update server configuration (Admin only).

        Bug fix: the previous code referenced the undefined names
        `server_config` and `ConfigManager`; persist through
        DocumentIndexerConfig.save_config, which merges into the stored file.
        """
        if user != "admin":
            raise HTTPException(status_code=403, detail="Access denied")
        DocumentIndexerConfig.save_config(new_config)
        return {"status": "success", "updated_config": new_config}

    @app.post("/delete")
    def delete_document(
        doc_id: str,
        collection: Optional[str] = Query(None),
        user: str = Depends(get_current_user)
    ):
        """
        Delete a specific document from ChromaDB.

        - Admins can delete from any collection
        - Users can only delete from their own collection
        """
        indexer = DocumentIndexer(DocumentIndexerConfig.load_config(), user)
        collection_name = collection or indexer.default_collection

        # Tenant isolation: non-admins may only touch their own collections.
        if user != "admin":
            if not collection_name.startswith(f"{user}_"):
                raise HTTPException(status_code=403, detail="Access denied")

        collection = indexer.chroma_client.get_or_create_collection(name=collection_name)

        # Chunk ids are "{doc_id}_{i}", so a prefix match selects the document.
        chunks_to_delete = [
            chunk_id for chunk_id in collection.get()['ids']
            if chunk_id.startswith(doc_id)
        ]

        if not chunks_to_delete:
            raise HTTPException(status_code=404, detail="Document not found")

        collection.delete(ids=chunks_to_delete)

        return {
            "status": "success",
            "deleted_chunks": len(chunks_to_delete),
            "document_id": doc_id
        }


    @app.post("/reindex")
    def reindex_document(
        doc_id: str,
        doc_type: str = Body(...),
        source: str = Body(...),
        collection: Optional[str] = Query(None),
        user: str = Depends(get_current_user)
    ):
        """
        Reindex a specific document:
        1. Delete existing document chunks
        2. Reindex from the source
        """
        indexer = DocumentIndexer(DocumentIndexerConfig.load_config(), user)
        collection_name = collection or indexer.default_collection

        if user != "admin":
            if not collection_name.startswith(f"{user}_"):
                raise HTTPException(status_code=403, detail="Access denied")

        # Drop the old chunks (direct call into the endpoint function above).
        delete_document(doc_id, collection_name, user)

        return indexer.index_source(source, doc_type, collection_name)


    @app.post("/reindexAll")
    def reindex_all(user: str = Depends(get_current_user)):
        """
        Reindex the database:
        - Admin reindexes every collection
        - Users reindex only their own collection
        """
        if user == "admin":
            collections = [
                collection.name for collection in
                DocumentIndexer(DocumentIndexerConfig.load_config()).chroma_client.list_collections()
            ]
        else:
            collections = [f"{user}_collection"]

        results = {}
        for collection_name in collections:
            try:
                collection = DocumentIndexer(
                    DocumentIndexerConfig.load_config()
                ).chroma_client.get_or_create_collection(name=collection_name)

                # Unique, non-empty source references recorded at index time.
                sources = set(
                    metadata.get('source', '') for metadata in
                    collection.get()['metadatas'] if metadata.get('source')
                )

                collection_results = []
                for source in sources:
                    # Infer the original document type from the source string.
                    doc_type = (
                        "pdf" if source.endswith('.pdf') else
                        "website" if source.startswith(('http://', 'https://')) else
                        "manual"
                    )

                    indexer = DocumentIndexer(
                        DocumentIndexerConfig.load_config(),
                        user
                    )
                    result = indexer.index_source(source, doc_type, collection_name)
                    collection_results.append(result)

                results[collection_name] = collection_results
            except Exception as e:
                # Best-effort: one failing collection must not abort the rest.
                results[collection_name] = {"error": str(e)}

        return results


    @app.post("/index")
    def index_content(
        doc_type: str = Body(...),
        source: str = Body(...),
        collection: Optional[str] = Query(None),
        user: str = Depends(get_current_user)
    ):
        """
        Index content into ChromaDB using user-specific default collections.
        """
        indexer = DocumentIndexer(DocumentIndexerConfig.load_config(), user)

        collection_name = collection or indexer.default_collection

        return indexer.index_source(source, doc_type, collection_name)


    @app.get("/list")
    def list_documents(user: str = Depends(get_current_user)):
        """List indexed documents from the caller's default collection.

        Security fix: previously every user listed the shared default
        collection; now listing is scoped to the caller's own collection
        (admins still see the configured default collection).
        """
        indexer = DocumentIndexer(DocumentIndexerConfig.load_config(), user)
        collection = indexer.chroma_client.get_or_create_collection(
            indexer.default_collection
        )
        return collection.get()


    @app.post("/search")
    def search_documents(
        query: str,
        top_k: Optional[int] = Query(None),
        collection: Optional[str] = Query(None),
        user: str = Depends(get_current_user)
    ):
        """Search indexed documents.

        Security fix: previously any authenticated user could query any
        collection (and defaulted to the shared one). Apply the same tenant
        check as /delete and /reindex, and default to the caller's own
        collection.
        """
        indexer = DocumentIndexer(DocumentIndexerConfig.load_config(), user)
        collection_name = collection or indexer.default_collection

        if user != "admin" and not collection_name.startswith(f"{user}_"):
            raise HTTPException(status_code=403, detail="Access denied")

        return indexer.search_documents(query, top_k, collection_name)


    return app


app = create_app()
| |