# llm-agent / dataset_handler.py
# Initial commit (d40b9df)
from datasets import load_dataset
from typing import List, Dict, Optional
import json
class DatasetHandler:
    """Handles loading and searching the CGIAR agricultural dataset."""

    def __init__(self, use_streaming: bool = True, max_samples: Optional[int] = None):
        """
        Initialize dataset handler.

        Args:
            use_streaming: If True, use streaming mode (faster, doesn't download all files)
            max_samples: Maximum number of samples to load (None = all, for testing use smaller number)
        """
        self.dataset = None
        self.loaded = False
        self.use_streaming = use_streaming
        self.max_samples = max_samples

    def load_dataset(self):
        """Load the CGIAR dataset from HuggingFace.

        Returns:
            The loaded dataset (streaming iterable or in-memory split).

        Raises:
            Exception: re-raised from ``datasets.load_dataset`` on failure,
                after logging the error.
        """
        if not self.loaded:
            try:
                print("Loading CGIAR dataset from HuggingFace (this may take a moment)...")
                if self.use_streaming:
                    # Streaming avoids downloading the full corpus up front;
                    # shards are fetched lazily as the dataset is iterated.
                    self.dataset = load_dataset(
                        "CGIAR/gardian-ai-ready-docs",
                        split="train",
                        streaming=True
                    )
                    print("Dataset loaded in streaming mode (lazy loading - files downloaded on-demand only)")
                else:
                    if self.max_samples:
                        # Slice syntax limits the download to a test-sized sample.
                        self.dataset = load_dataset(
                            "CGIAR/gardian-ai-ready-docs",
                            split=f"train[:{self.max_samples}]"
                        )
                        print(f"Dataset loaded successfully! Loaded {len(self.dataset)} documents (sample)")
                    else:
                        self.dataset = load_dataset("CGIAR/gardian-ai-ready-docs", split="train")
                        print(f"Dataset loaded successfully! Total documents: {len(self.dataset)}")
                self.loaded = True
            except Exception as e:
                print(f"Error loading dataset: {e}")
                raise
        return self.dataset

    @staticmethod
    def _summarize(doc: Dict) -> Dict:
        """Build the flat result record shared by the search/browse methods.

        None-valued fields are normalized to safe defaults ('' / [] / 0) so a
        single malformed document cannot crash a whole search pass.
        """
        metadata = doc.get('metadata') or {}
        return {
            'title': doc.get('title') or '',
            'abstract': doc.get('abstract') or '',
            'keywords': doc.get('keywords') or [],
            'url': metadata.get('url', ''),
            'source': metadata.get('source', ''),
            'pageCount': doc.get('pageCount') or 0
        }

    def search_by_keyword(self, keyword: str, limit: int = 5) -> List[Dict]:
        """
        Search documents by keyword in title, abstract, or keywords.

        Args:
            keyword: Search keyword (matched case-insensitively as a substring)
            limit: Maximum number of results to return

        Returns:
            List of matching documents (possibly partial if errors occur)
        """
        if not self.loaded:
            self.load_dataset()
        keyword_lower = keyword.lower()
        results = []
        checked = 0
        # In streaming mode cap the scan so a rare keyword cannot trigger an
        # unbounded crawl of the remote corpus.
        max_to_check = 300 if self.use_streaming else None
        consecutive_errors = 0
        max_consecutive_errors = 3
        try:
            for doc in self.dataset:
                try:
                    checked += 1
                    # Show progress every 100 documents
                    if checked % 100 == 0:
                        print(f"[DATASET] Checked {checked} documents, found {len(results)} matches so far...")
                    if max_to_check and checked > max_to_check:
                        print(f"[DATASET] Reached search limit of {max_to_check} documents")
                        break
                    # None-safe extraction: fields may be present but null.
                    title = (doc.get('title') or '').lower()
                    abstract = (doc.get('abstract') or '').lower()
                    keywords = ' '.join(doc.get('keywords') or []).lower()
                    if keyword_lower in title or keyword_lower in abstract or keyword_lower in keywords:
                        results.append(self._summarize(doc))
                        if len(results) >= limit:
                            break
                    # Any document processed without raising resets the
                    # failure streak (not just matching ones).
                    consecutive_errors = 0
                except Exception:
                    consecutive_errors += 1
                    if consecutive_errors >= max_consecutive_errors:
                        print(f"[DATASET] Too many consecutive errors ({consecutive_errors}), stopping search")
                        break
                    # Skip the bad document and continue scanning.
                    continue
        except Exception as e:
            print(f"[DATASET] Error during search: {e}")
            # Fall through and return whatever partial results we gathered.
        if results:
            print(f"[DATASET] Found {len(results)} results after checking {checked} documents")
        else:
            print(f"[DATASET] No results found after checking {checked} documents")
        return results

    def search_by_topic(self, topic: str, limit: int = 5) -> List[Dict]:
        """
        Search documents by agricultural topic.

        Args:
            topic: Agricultural topic (e.g., "crop management", "pest control")
            limit: Maximum number of results to return

        Returns:
            List of matching documents
        """
        # Topics are matched the same way as free-text keywords.
        return self.search_by_keyword(topic, limit)

    def get_document_by_title(self, title: str) -> Optional[Dict]:
        """
        Retrieve a specific document by its title (case-insensitive exact match).

        Args:
            title: Document title

        Returns:
            Document data (including chapters and figures) or None if not found
        """
        if not self.loaded:
            self.load_dataset()
        title_lower = title.lower()
        checked = 0
        max_to_check = 300 if self.use_streaming else None  # Very aggressive limit
        consecutive_errors = 0
        max_consecutive_errors = 3
        try:
            for doc in self.dataset:
                try:
                    checked += 1
                    if max_to_check and checked > max_to_check:
                        break
                    # None-safe comparison (title field may be null).
                    if (doc.get('title') or '').lower() == title_lower:
                        # Full record: includes chapters/figures on top of the
                        # shared summary fields.
                        record = self._summarize(doc)
                        record['chapters'] = doc.get('chapters') or []
                        record['figures'] = doc.get('figures') or []
                        return record
                    consecutive_errors = 0
                except Exception:
                    consecutive_errors += 1
                    if consecutive_errors >= max_consecutive_errors:
                        break
                    continue
        except Exception as e:
            print(f"[DATASET] Error searching for document: {e}")
        return None

    def get_random_documents(self, limit: int = 3) -> List[Dict]:
        """
        Get random documents from the dataset.

        Note: in streaming mode the corpus cannot be indexed, so this returns
        the first `limit` documents from the stream rather than a true random
        sample.

        Args:
            limit: Number of documents to return

        Returns:
            List of documents
        """
        if not self.loaded:
            self.load_dataset()
        import random
        results = []
        if self.use_streaming:
            for doc in self.dataset:
                if len(results) >= limit:
                    break
                results.append(self._summarize(doc))
        else:
            # Indexable dataset: take a genuine random sample without
            # replacement, clamped to the dataset size.
            indices = random.sample(range(len(self.dataset)), min(limit, len(self.dataset)))
            for idx in indices:
                results.append(self._summarize(self.dataset[idx]))
        return results

    def format_document_summary(self, doc: Dict) -> str:
        """
        Format a document for display in the chat.

        Args:
            doc: Document dictionary

        Returns:
            Formatted markdown string representation
        """
        abstract = doc.get('abstract') or 'N/A'
        summary = f"**Title:** {doc.get('title', 'N/A')}\n"
        # Truncate long abstracts; only show an ellipsis when text was cut.
        if len(abstract) > 500:
            summary += f"**Abstract:** {abstract[:500]}...\n"
        else:
            summary += f"**Abstract:** {abstract}\n"
        if doc.get('keywords'):
            summary += f"**Keywords:** {', '.join(doc.get('keywords', []))}\n"
        summary += f"**Source:** {doc.get('source', 'N/A')}\n"
        if doc.get('url'):
            summary += f"**URL:** {doc.get('url')}\n"
        return summary