# NOTE(review): the lines above this module previously contained scraped page
# chrome ("Spaces: / Sleeping") from the Hugging Face UI, not source content.
import json
from typing import Dict, List, Optional

from datasets import load_dataset
class DatasetHandler:
    """Handles loading and searching the CGIAR agricultural dataset.

    Wraps the HuggingFace ``CGIAR/gardian-ai-ready-docs`` dataset and offers
    keyword/title search plus chat-display formatting.  In streaming mode,
    scans are capped so a search never downloads the entire corpus.
    """

    # Streaming scans stop after this many records so a miss stays cheap.
    _STREAM_SCAN_LIMIT = 300
    # Abort a scan after this many records in a row raise while being parsed.
    _MAX_CONSECUTIVE_ERRORS = 3

    def __init__(self, use_streaming: bool = True, max_samples: Optional[int] = None):
        """
        Initialize dataset handler.

        Args:
            use_streaming: If True, use streaming mode (faster, doesn't download all files)
            max_samples: Maximum number of samples to load (None = all; use a small
                number for testing). Only honored in non-streaming mode.
        """
        self.dataset = None
        self.loaded = False
        self.use_streaming = use_streaming
        self.max_samples = max_samples

    def load_dataset(self):
        """Load the CGIAR dataset from HuggingFace (idempotent).

        Returns:
            The loaded dataset object (streaming iterable or indexable split).

        Raises:
            Exception: re-raised from ``datasets.load_dataset`` on failure.
        """
        if not self.loaded:
            try:
                print("Loading CGIAR dataset from HuggingFace (this may take a moment)...")
                if self.use_streaming:
                    self.dataset = load_dataset(
                        "CGIAR/gardian-ai-ready-docs",
                        split="train",
                        streaming=True,
                    )
                    print("Dataset loaded in streaming mode (lazy loading - files downloaded on-demand only)")
                elif self.max_samples:
                    # Slice syntax downloads only the requested prefix.
                    self.dataset = load_dataset(
                        "CGIAR/gardian-ai-ready-docs",
                        split=f"train[:{self.max_samples}]",
                    )
                    print(f"Dataset loaded successfully! Loaded {len(self.dataset)} documents (sample)")
                else:
                    self.dataset = load_dataset("CGIAR/gardian-ai-ready-docs", split="train")
                    print(f"Dataset loaded successfully! Total documents: {len(self.dataset)}")
                self.loaded = True
            except Exception as e:
                print(f"Error loading dataset: {e}")
                raise
        return self.dataset

    @staticmethod
    def _doc_summary(doc: Dict) -> Dict:
        """Project a raw dataset record onto the summary shape used by results.

        Raw records can carry present-but-None fields, so ``or`` fallbacks
        normalize them to empty values instead of letting ``None`` leak out.
        """
        metadata = doc.get('metadata') or {}
        return {
            'title': doc.get('title') or '',
            'abstract': doc.get('abstract') or '',
            'keywords': doc.get('keywords') or [],
            'url': metadata.get('url', ''),
            'source': metadata.get('source', ''),
            'pageCount': doc.get('pageCount') or 0,
        }

    def search_by_keyword(self, keyword: str, limit: int = 5) -> List[Dict]:
        """
        Search documents by keyword in title, abstract, or keywords.

        Args:
            keyword: Search keyword (case-insensitive substring match)
            limit: Maximum number of results to return

        Returns:
            List of matching document summaries (may be partial on errors)
        """
        if not self.loaded:
            self.load_dataset()
        keyword_lower = keyword.lower()
        results: List[Dict] = []
        checked = 0
        # Only cap the scan in streaming mode; local datasets iterate cheaply.
        max_to_check = self._STREAM_SCAN_LIMIT if self.use_streaming else None
        consecutive_errors = 0
        try:
            for doc in self.dataset:
                try:
                    checked += 1
                    # Show progress every 100 documents.
                    if checked % 100 == 0:
                        print(f"[DATASET] Checked {checked} documents, found {len(results)} matches so far...")
                    if max_to_check and checked > max_to_check:
                        print(f"[DATASET] Reached search limit of {max_to_check} documents")
                        break
                    # Fields can be None in raw records; normalize before matching.
                    fields = (
                        doc.get('title') or '',
                        doc.get('abstract') or '',
                        ' '.join(map(str, doc.get('keywords') or [])),
                    )
                    if any(keyword_lower in field.lower() for field in fields):
                        results.append(self._doc_summary(doc))
                    consecutive_errors = 0  # this record parsed cleanly
                    if len(results) >= limit:
                        break
                except Exception:
                    consecutive_errors += 1
                    if consecutive_errors >= self._MAX_CONSECUTIVE_ERRORS:
                        print(f"[DATASET] Too many consecutive errors ({consecutive_errors}), stopping search")
                        break
                    continue  # best-effort: skip the malformed record
        except Exception as e:
            print(f"[DATASET] Error during search: {e}")
        # Return partial results if available.
        if results:
            print(f"[DATASET] Found {len(results)} results after checking {checked} documents")
        else:
            print(f"[DATASET] No results found after checking {checked} documents")
        return results

    def search_by_topic(self, topic: str, limit: int = 5) -> List[Dict]:
        """
        Search documents by agricultural topic.

        Args:
            topic: Agricultural topic (e.g., "crop management", "pest control")
            limit: Maximum number of results to return

        Returns:
            List of matching documents
        """
        # Topics are matched the same way as plain keywords.
        return self.search_by_keyword(topic, limit)

    def get_document_by_title(self, title: str) -> Optional[Dict]:
        """
        Retrieve a specific document by its title (case-insensitive exact match).

        Args:
            title: Document title

        Returns:
            Full document data (summary fields plus chapters/figures) or None
        """
        if not self.loaded:
            self.load_dataset()
        title_lower = title.lower()
        checked = 0
        # Very aggressive limit: streaming lookups stay cheap even on a miss.
        max_to_check = self._STREAM_SCAN_LIMIT if self.use_streaming else None
        consecutive_errors = 0
        try:
            for doc in self.dataset:
                try:
                    checked += 1
                    if max_to_check and checked > max_to_check:
                        break
                    if (doc.get('title') or '').lower() == title_lower:
                        full = self._doc_summary(doc)
                        full['chapters'] = doc.get('chapters') or []
                        full['figures'] = doc.get('figures') or []
                        return full
                    consecutive_errors = 0  # reset on a cleanly parsed record
                except Exception:
                    consecutive_errors += 1
                    if consecutive_errors >= self._MAX_CONSECUTIVE_ERRORS:
                        break
                    continue
        except Exception as e:
            print(f"[DATASET] Error searching for document: {e}")
        return None

    def get_random_documents(self, limit: int = 3) -> List[Dict]:
        """
        Get documents from the dataset for display.

        Note: in streaming mode true random access would require scanning the
        whole corpus, so the first ``limit`` records of the stream are
        returned instead. In non-streaming mode a genuine random sample is
        drawn.

        Args:
            limit: Number of documents to return

        Returns:
            List of document summaries
        """
        if not self.loaded:
            self.load_dataset()
        import random
        if self.use_streaming:
            # Head of the stream; sampling would force a full download.
            results: List[Dict] = []
            for doc in self.dataset:
                if len(results) >= limit:
                    break
                results.append(self._doc_summary(doc))
            return results
        size = len(self.dataset)
        indices = random.sample(range(size), min(limit, size))
        return [self._doc_summary(self.dataset[idx]) for idx in indices]

    def format_document_summary(self, doc: Dict) -> str:
        """
        Format a document for display in the chat.

        Args:
            doc: Document dictionary (summary shape from the search methods)

        Returns:
            Markdown-formatted string representation, newline-terminated
        """
        abstract = doc.get('abstract') or 'N/A'
        if len(abstract) > 500:
            # Only mark truncation when it actually happened.
            abstract = abstract[:500] + '...'
        lines = [
            f"**Title:** {doc.get('title') or 'N/A'}",
            f"**Abstract:** {abstract}",
        ]
        if doc.get('keywords'):
            lines.append(f"**Keywords:** {', '.join(doc.get('keywords', []))}")
        lines.append(f"**Source:** {doc.get('source') or 'N/A'}")
        if doc.get('url'):
            lines.append(f"**URL:** {doc.get('url')}")
        return '\n'.join(lines) + '\n'