Spaces:
Sleeping
Sleeping
| import sqlite3 | |
| from typing import List, Optional, Dict, Any | |
| from langchain.schema import Document | |
| from langchain.document_loaders.base import BaseLoader | |
| import logging | |
| import re | |
| from datetime import datetime | |
| # Set up logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class DiaryDataLoader(BaseLoader):
    """
    Custom LangChain document loader for diary entries from a SQLite database.

    Enhanced with detailed metadata extraction (tags, location, people,
    day of week, mood) for better indexing.

    Note:
        Table and column names are interpolated directly into the SQL text
        (SQLite cannot bind identifiers as parameters), so they must come
        from trusted configuration, never from user input. Row values such
        as user_id and dates are always bound as query parameters.
    """

    def __init__(
        self,
        db_path: str,
        table_name: str = "diary_entries",
        content_column: str = "content",
        date_column: str = "date",
        tags_column: str = "tags",
        id_column: str = "id",
        user_id: int = 1
    ):
        """
        Initialize the DiaryDataLoader.

        Args:
            db_path (str): Path to the SQLite database file
            table_name (str): Name of the table containing diary entries
            content_column (str): Name of the column containing diary content
            date_column (str): Name of the column containing entry dates
            tags_column (str): Name of the column containing entry tags
            id_column (str): Name of the column containing entry IDs
            user_id (int): ID of the user for filtering diary entries
        """
        self.db_path = db_path
        self.table_name = table_name
        self.content_column = content_column
        self.date_column = date_column
        self.tags_column = tags_column
        self.id_column = id_column
        self.user_id = user_id

    def _extract_tags_from_content(self, content: str) -> List[str]:
        """
        Extract #tags from content string.

        Args:
            content: The diary content string

        Returns:
            List of unique, lowercased tags found (without the # symbol)
        """
        if not content:
            return []
        # A tag is #word, optionally extended with -/_ separated parts
        # (e.g. #road-trip, #new_year)
        tag_pattern = r'#(\w+(?:[_-]\w+)*)'
        matches = re.findall(tag_pattern, content, re.IGNORECASE)
        # Lowercase and deduplicate (order is not significant)
        return list({tag.lower() for tag in matches})

    def _extract_location_from_content(self, content: str) -> Optional[str]:
        """
        Extract location information from content using common patterns.

        Args:
            content: The diary content string

        Returns:
            Location string if found, None otherwise
        """
        if not content:
            return None
        # Heuristic patterns, checked in order; the first match wins
        location_patterns = [
            r'at\s+([A-Z][a-zA-Z\s]+(?:Park|Beach|Mall|Store|Restaurant|Cafe|Office|Home|School|University))',
            r'in\s+([A-Z][a-zA-Z\s]+(?:City|District|Area|Street|Road))',
            r'went\s+to\s+([A-Z][a-zA-Z\s]+)',
            r'visited\s+([A-Z][a-zA-Z\s]+)',
            r'location:\s*([A-Za-z\s]+)',
            r'place:\s*([A-Za-z\s]+)'
        ]
        for pattern in location_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            if matches:
                return matches[0].strip()
        return None

    def _extract_people_from_content(self, content: str) -> List[str]:
        """
        Extract people/relationships mentioned in content.

        Args:
            content: The diary content string

        Returns:
            List of people/relationships mentioned (lowercased, deduplicated)
        """
        if not content:
            return []
        # Heuristic relationship patterns; the first pattern has two groups,
        # which re.findall returns as tuples
        people_patterns = [
            r'with\s+(my\s+)?(\w+(?:\s+\w+)?)',
            r'(mom|dad|mother|father|sister|brother|friend|colleague|boss|teacher)',
            r'(family|friends|team|colleagues)',
            r'met\s+([\w\s]+)',
            r'talked\s+to\s+([\w\s]+)'
        ]
        people = set()
        for pattern in people_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            for match in matches:
                if isinstance(match, tuple):
                    # Multi-group pattern: keep every non-empty group
                    for part in match:
                        if part.strip():
                            people.add(part.strip().lower())
                else:
                    people.add(match.strip().lower())
        # Filter out short words and common stop words that are not people
        exclude_words = {'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
        people = [p for p in people if p not in exclude_words and len(p) > 2]
        return list(people)

    def _get_day_of_week(self, date_str: str) -> str:
        """
        Get day of week from date string.

        Args:
            date_str: Date string in YYYY-MM-DD format

        Returns:
            Day of week (e.g., 'Monday'), or 'Unknown' when the date is
            missing or malformed.
        """
        try:
            date_obj = datetime.strptime(date_str, '%Y-%m-%d')
            return date_obj.strftime('%A')
        except (TypeError, ValueError):
            # Fix: narrowed from a bare `except:` so unrelated errors
            # (e.g. KeyboardInterrupt) are no longer swallowed.
            # TypeError covers a None/non-string date from the DB.
            return 'Unknown'

    def _extract_content_from_structured_format(self, raw_content: str) -> tuple:
        """
        Extract actual content from structured format like:
            Title: xxxx
            Type: Text
            Content: actual content here

        Returns:
            tuple: (title, actual_content); when no "Content: " line is
            found, the original text is returned unchanged as the content.
        """
        lines = raw_content.strip().split('\n')
        title = ""
        content = ""
        for line in lines:
            if line.startswith("Title: "):
                title = line.replace("Title: ", "").strip()
            elif line.startswith("Content: "):
                content = line.replace("Content: ", "").strip()
        # If no structured format found, return original content
        if not content:
            content = raw_content
        return title, content

    def load(self) -> List[Document]:
        """
        Load diary entries from the database and convert them to LangChain Documents.

        Returns:
            List[Document]: List of LangChain Document objects

        Raises:
            sqlite3.Error: On any database failure (logged, then re-raised)
        """
        documents = []
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            conn.row_factory = sqlite3.Row  # Enable accessing columns by name
            cursor = conn.cursor()
            # Build the SQL query with all required columns (identifiers come
            # from trusted loader configuration; values are bound parameters)
            columns = [self.id_column, self.date_column, self.content_column, self.tags_column]
            query = f"SELECT {', '.join(columns)} FROM {self.table_name} WHERE user_id = ? ORDER BY {self.date_column} DESC"
            cursor.execute(query, (self.user_id,))
            rows = cursor.fetchall()
            logger.info(f"Loaded {len(rows)} diary entries from database")
            # Convert each row to a LangChain Document with enhanced metadata
            for row in rows:
                # sqlite3.Row supports keys(); the fallback handles plain tuples
                row_dict = dict(row) if hasattr(row, 'keys') else {
                    self.id_column: row[0],
                    self.date_column: row[1],
                    self.content_column: row[2],
                    self.tags_column: row[3] if len(row) > 3 else ""
                }
                raw_content = row_dict[self.content_column]
                date = row_dict[self.date_column]
                entry_id = row_dict.get(self.id_column, "unknown")
                db_tags = row_dict.get(self.tags_column, "")
                # Split "Title:/Content:" structured entries into parts
                title, actual_content = self._extract_content_from_structured_format(raw_content)
                # Combine #tags found in the text with comma-separated DB tags
                content_tags = self._extract_tags_from_content(actual_content)
                db_tag_list = [tag.strip() for tag in db_tags.split(',') if tag.strip()] if db_tags else []
                all_tags = list(set(content_tags + db_tag_list))  # Combine and deduplicate
                location = self._extract_location_from_content(actual_content)
                people = self._extract_people_from_content(actual_content)
                day_of_week = self._get_day_of_week(date)
                # Create comprehensive metadata for the document
                metadata = {
                    "source": self.db_path,
                    "entry_id": str(entry_id),
                    "date": date,
                    "day_of_week": day_of_week,
                    "type": "diary_entry",
                    "tags": all_tags,
                    "tag_count": len(all_tags),
                    "content_length": len(actual_content),
                    "word_count": len(actual_content.split())
                }
                # Add optional metadata if available
                if title:
                    metadata["title"] = title
                if location:
                    metadata["location"] = location
                if people:
                    metadata["people"] = people
                    metadata["people_count"] = len(people)
                # Add mood/sentiment tags if present
                mood_tags = [tag for tag in all_tags if tag in ['happy', 'sad', 'excited', 'tired', 'angry', 'peaceful', 'stressed', 'grateful', 'frustrated', 'motivated']]
                if mood_tags:
                    metadata["mood_tags"] = mood_tags
                # Create Document object with actual content
                document = Document(
                    page_content=actual_content,
                    metadata=metadata
                )
                documents.append(document)
            logger.info(f"Successfully converted {len(documents)} entries to Documents")
        except sqlite3.Error as e:
            logger.error(f"Database error: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading diary data: {e}")
            raise
        finally:
            # Fix: close the connection even when an exception is raised
            # (previously the connection leaked on every error path)
            if conn is not None:
                conn.close()
        return documents

    def load_by_date_range(self, start_date: str, end_date: str) -> List[Document]:
        """
        Load diary entries within a specific date range (BETWEEN is inclusive).

        Args:
            start_date (str): Start date in YYYY-MM-DD format
            end_date (str): End date in YYYY-MM-DD format

        Returns:
            List[Document]: Filtered list of Document objects

        Raises:
            sqlite3.Error: On any database failure (logged, then re-raised)
        """
        documents = []
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            columns = [self.content_column, self.date_column]
            query = f"""
                SELECT {', '.join(columns)}
                FROM {self.table_name}
                WHERE user_id = ? AND {self.date_column} BETWEEN ? AND ?
                ORDER BY {self.date_column}
            """
            cursor.execute(query, (self.user_id, start_date, end_date))
            rows = cursor.fetchall()
            logger.info(f"Loaded {len(rows)} diary entries from {start_date} to {end_date}")
            for row in rows:
                raw_content = row[self.content_column]
                date = row[self.date_column]
                # Split "Title:/Content:" structured entries into parts
                title, actual_content = self._extract_content_from_structured_format(raw_content)
                metadata = {
                    "source": self.db_path,
                    "date": date,
                    "type": "diary_entry",
                    "date_range": f"{start_date}_to_{end_date}"
                }
                # Add title to metadata if available
                if title:
                    metadata["title"] = title
                document = Document(
                    page_content=actual_content,
                    metadata=metadata
                )
                documents.append(document)
        except sqlite3.Error as e:
            logger.error(f"Database error: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading diary data by date range: {e}")
            raise
        finally:
            # Fix: guarantee connection cleanup on error paths
            if conn is not None:
                conn.close()
        return documents

    def get_table_info(self) -> dict:
        """
        Get information about the database table structure.

        Returns:
            dict: Table information including columns and row count

        Raises:
            sqlite3.Error: On any database failure (logged, then re-raised)
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            # Get table schema (identifier comes from trusted config)
            cursor.execute(f"PRAGMA table_info({self.table_name})")
            columns = cursor.fetchall()
            # Get row count
            cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
            row_count = cursor.fetchone()[0]
            return {
                "table_name": self.table_name,
                # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
                "columns": [{"name": col[1], "type": col[2]} for col in columns],
                "row_count": row_count
            }
        except sqlite3.Error as e:
            logger.error(f"Database error: {e}")
            raise
        finally:
            # Fix: close the connection even when a query fails
            if conn is not None:
                conn.close()
class DiaryContentPreprocessor:
    """
    Preprocessor for diary content to clean and standardize text before indexing.
    """

    def __init__(
        self,
        remove_extra_whitespace: bool = True,
        normalize_line_breaks: bool = True,
        min_content_length: int = 10,
        max_content_length: Optional[int] = None
    ):
        """
        Initialize the content preprocessor.

        Args:
            remove_extra_whitespace (bool): Collapse whitespace runs to single
                spaces. NOTE: this uses str.split(), which also collapses
                line breaks into spaces.
            normalize_line_breaks (bool): Normalize line breaks to single newlines
            min_content_length (int): Minimum content length to keep
            max_content_length (int, optional): Maximum content length to keep
        """
        self.remove_extra_whitespace = remove_extra_whitespace
        self.normalize_line_breaks = normalize_line_breaks
        self.min_content_length = min_content_length
        self.max_content_length = max_content_length

    def preprocess_content(self, content: str) -> str:
        """
        Preprocess diary content text.

        Args:
            content (str): Raw diary content

        Returns:
            str: Preprocessed content, or "" when the input is empty,
            not a string, or shorter than min_content_length after cleaning.
        """
        if not content or not isinstance(content, str):
            return ""
        text = content
        if self.remove_extra_whitespace:
            # split() with no arguments splits on ALL whitespace (spaces,
            # tabs, and newlines), so joining collapses everything to
            # single spaces — line breaks do not survive this step.
            text = ' '.join(text.split())
        if self.normalize_line_breaks:
            # Convert Windows/old-Mac line endings, then squeeze any run
            # of blank lines down to a single newline.
            text = text.replace('\r\n', '\n').replace('\r', '\n')
            text = re.sub(r'\n+', '\n', text)
        text = text.strip()
        # Enforce length constraints after cleaning
        if len(text) < self.min_content_length:
            logger.warning(f"Content too short ({len(text)} chars), skipping")
            return ""
        if self.max_content_length and len(text) > self.max_content_length:
            logger.warning(f"Content too long ({len(text)} chars), truncating")
            text = text[:self.max_content_length]
        return text

    def preprocess_documents(self, documents: List[Document]) -> List[Document]:
        """
        Preprocess a list of Document objects, dropping any whose content
        becomes empty after cleaning.

        Args:
            documents (List[Document]): List of documents to preprocess

        Returns:
            List[Document]: List of preprocessed documents
        """
        kept = []
        for source_doc in documents:
            cleaned = self.preprocess_content(source_doc.page_content)
            # Skip entries that were emptied (too short / invalid) by cleaning
            if not cleaned:
                continue
            kept.append(Document(
                page_content=cleaned,
                metadata=source_doc.metadata.copy()
            ))
        logger.info(f"Preprocessed {len(documents)} documents, kept {len(kept)}")
        return kept
| def load_all_entries(self, user_id: int = None) -> List[Dict[str, Any]]: | |
| """ | |
| Load all diary entries for a specific user. | |
| Args: | |
| user_id: User ID to filter entries | |
| Returns: | |
| List of diary entry dictionaries | |
| """ | |
| if user_id is None: | |
| user_id = self.user_id | |
| entries = [] | |
| try: | |
| conn = sqlite3.connect(self.db_path) | |
| conn.row_factory = sqlite3.Row | |
| cursor = conn.cursor() | |
| query = f""" | |
| SELECT id, user_id, date, content, tags, created_at | |
| FROM {self.table_name} | |
| WHERE user_id = ? | |
| ORDER BY date DESC, created_at DESC | |
| """ | |
| cursor.execute(query, (user_id,)) | |
| rows = cursor.fetchall() | |
| for row in rows: | |
| entries.append({ | |
| 'id': row['id'], | |
| 'user_id': row['user_id'], | |
| 'date': row['date'], | |
| 'content': row['content'], | |
| 'tags': row['tags'] or '', | |
| 'created_at': row['created_at'] | |
| }) | |
| conn.close() | |
| logger.info(f"Loaded {len(entries)} entries for user {user_id}") | |
| except sqlite3.Error as e: | |
| logger.error(f"Database error loading entries: {e}") | |
| return entries | |
| def load_entries_since(self, since_date, user_id: int = None) -> List[Dict[str, Any]]: | |
| """ | |
| Load diary entries since a specific date. | |
| Args: | |
| since_date: datetime object or ISO string | |
| user_id: User ID to filter entries | |
| Returns: | |
| List of diary entry dictionaries | |
| """ | |
| if user_id is None: | |
| user_id = self.user_id | |
| entries = [] | |
| try: | |
| # Convert datetime to string if needed | |
| if hasattr(since_date, 'isoformat'): | |
| since_str = since_date.isoformat() | |
| else: | |
| since_str = str(since_date) | |
| conn = sqlite3.connect(self.db_path) | |
| conn.row_factory = sqlite3.Row | |
| cursor = conn.cursor() | |
| query = f""" | |
| SELECT id, user_id, date, content, tags, created_at | |
| FROM {self.table_name} | |
| WHERE user_id = ? AND created_at > ? | |
| ORDER BY date DESC, created_at DESC | |
| """ | |
| cursor.execute(query, (user_id, since_str)) | |
| rows = cursor.fetchall() | |
| for row in rows: | |
| entries.append({ | |
| 'id': row['id'], | |
| 'user_id': row['user_id'], | |
| 'date': row['date'], | |
| 'content': row['content'], | |
| 'tags': row['tags'] or '', | |
| 'created_at': row['created_at'] | |
| }) | |
| conn.close() | |
| logger.info(f"Loaded {len(entries)} entries since {since_str} for user {user_id}") | |
| except sqlite3.Error as e: | |
| logger.error(f"Database error loading entries since {since_date}: {e}") | |
| return entries | |
# Example usage / manual smoke test (requires the diary SQLite DB to exist)
if __name__ == "__main__":
    # Initialize the loader with explicit column mapping for user_id=1 (default)
    loader = DiaryDataLoader(
        db_path="../streamlit_app/backend/diary.db",
        table_name="diary_entries",
        content_column="content",
        date_column="date"
    )
    # Load all documents for the default user
    documents = loader.load()
    print(f"Loaded {len(documents)} diary entries")
    # Load documents within an explicit (inclusive) date range.
    # Fix: the old message said "from 2024" although the range spans 2024-2026.
    filtered_docs = loader.load_by_date_range("2024-01-01", "2026-12-31")
    print(f"Loaded {len(filtered_docs)} entries from 2024-01-01 to 2026-12-31")
    # Inspect the table structure and row count
    table_info = loader.get_table_info()
    print(f"Table info: {table_info}")
    # Preview the loaded document contents
    for doc in documents:
        print(f"Document content: {doc.page_content}")