Spaces:
Sleeping
Sleeping
| """Pinecone client for uploading and managing standard records.""" | |
| from __future__ import annotations | |
| import time | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from collections.abc import Callable | |
| from typing import Any | |
| from loguru import logger | |
| from pinecone import Pinecone | |
| from pinecone.exceptions import PineconeException | |
| from src.mcp_config import get_mcp_settings | |
| from tools.pinecone_models import PineconeRecord | |
# Module-level settings singleton, loaded once at import time; supplies the
# Pinecone API key, index name, and namespace used by PineconeClient.
settings = get_mcp_settings()
class PineconeClient:
    """Client for interacting with Pinecone index."""

    def __init__(self) -> None:
        """Initialize Pinecone SDK from config settings.

        Raises:
            ValueError: If the PINECONE_API_KEY setting is empty/unset.
        """
        api_key = settings.pinecone_api_key
        if not api_key:
            raise ValueError("PINECONE_API_KEY environment variable not set")
        self.pc = Pinecone(api_key=api_key)
        self.index_name = settings.pinecone_index_name
        self.namespace = settings.pinecone_namespace
        # Lazily-created index handle; populated on first access via `index`.
        self._index = None
| def index(self): | |
| """Get the index object, creating it if needed.""" | |
| if self._index is None: | |
| self._index = self.pc.Index(self.index_name) | |
| return self._index | |
| def validate_index(self) -> None: | |
| """ | |
| Check index exists with pc.has_index(), raise helpful error if not. | |
| Raises: | |
| ValueError: If index does not exist, with instructions to create it. | |
| """ | |
| if not self.pc.has_index(name=self.index_name): | |
| raise ValueError( | |
| f"Index '{self.index_name}' not found. Create it with:\n" | |
| f"pc index create -n {self.index_name} -m cosine -c aws -r us-east-1 " | |
| f"--model llama-text-embed-v2 --field_map text=content" | |
| ) | |
| def ensure_index_exists(self) -> bool: | |
| """ | |
| Check if index exists, create it if not. | |
| Creates the index with integrated embeddings using llama-text-embed-v2 model. | |
| Returns: | |
| True if index was created, False if it already existed. | |
| """ | |
| if self.pc.has_index(name=self.index_name): | |
| logger.info(f"Index '{self.index_name}' already exists") | |
| return False | |
| logger.info(f"Creating index '{self.index_name}' with integrated embeddings...") | |
| self.pc.create_index_for_model( | |
| name=self.index_name, | |
| cloud="aws", | |
| region="us-east-1", | |
| embed={ | |
| "model": "llama-text-embed-v2", | |
| "field_map": {"text": "content"}, | |
| }, | |
| ) | |
| logger.info(f"Successfully created index '{self.index_name}'") | |
| return True | |
| def get_index_stats(self) -> dict[str, Any]: | |
| """ | |
| Get index statistics including vector count and namespaces. | |
| Returns: | |
| Dictionary with index stats including total_vector_count and namespaces. | |
| """ | |
| stats = self.index.describe_index_stats() | |
| return { | |
| "total_vector_count": stats.total_vector_count, | |
| "namespaces": dict(stats.namespaces) if stats.namespaces else {}, | |
| } | |
| def exponential_backoff_retry( | |
| func: Callable[[], Any], max_retries: int = 5 | |
| ) -> Any: | |
| """ | |
| Retry function with exponential backoff on 429/5xx, fail on 4xx. | |
| Args: | |
| func: Function to retry (should be a callable that takes no args) | |
| max_retries: Maximum number of retry attempts | |
| Returns: | |
| Result of func() | |
| Raises: | |
| PineconeException: If retries exhausted or non-retryable error | |
| """ | |
| for attempt in range(max_retries): | |
| try: | |
| return func() | |
| except PineconeException as e: | |
| status_code = getattr(e, "status", None) | |
| # Only retry transient errors | |
| if status_code and (status_code >= 500 or status_code == 429): | |
| if attempt < max_retries - 1: | |
| delay = min(2 ** attempt, 60) # Cap at 60s | |
| logger.warning( | |
| f"Retryable error (status {status_code}), " | |
| f"retrying in {delay}s (attempt {attempt + 1}/{max_retries})" | |
| ) | |
| time.sleep(delay) | |
| else: | |
| logger.error( | |
| f"Max retries ({max_retries}) exceeded for retryable error" | |
| ) | |
| raise | |
| else: | |
| # Don't retry client errors | |
| logger.error(f"Non-retryable error (status {status_code}): {e}") | |
| raise | |
| except Exception as e: | |
| # Non-Pinecone exceptions should not be retried | |
| logger.error(f"Non-retryable exception: {e}") | |
| raise | |
| def batch_upsert( | |
| self, records: list[PineconeRecord], batch_size: int = 96 | |
| ) -> None: | |
| """ | |
| Upsert records in batches of specified size with rate limiting. | |
| Args: | |
| records: List of PineconeRecord objects to upsert | |
| batch_size: Number of records per batch (default: 96) | |
| """ | |
| if not records: | |
| logger.info("No records to upsert") | |
| return | |
| total_batches = (len(records) + batch_size - 1) // batch_size | |
| logger.info( | |
| f"Upserting {len(records)} records in {total_batches} batch(es) " | |
| f"(batch size: {batch_size})" | |
| ) | |
| for i in range(0, len(records), batch_size): | |
| batch = records[i : i + batch_size] | |
| batch_num = (i // batch_size) + 1 | |
| # Convert PineconeRecord models to dict format for Pinecone | |
| batch_dicts = [self._record_to_dict(record) for record in batch] | |
| logger.debug(f"Upserting batch {batch_num}/{total_batches} ({len(batch)} records)") | |
| # Retry with exponential backoff | |
| self.exponential_backoff_retry( | |
| lambda b=batch_dicts: self.index.upsert_records( | |
| namespace=self.namespace, records=b | |
| ) | |
| ) | |
| # Rate limiting between batches | |
| if i + batch_size < len(records): | |
| time.sleep(0.1) | |
| logger.info(f"Successfully upserted {len(records)} records") | |
| def _record_to_dict(record: PineconeRecord) -> dict[str, Any]: | |
| """ | |
| Convert PineconeRecord model to dict format for Pinecone API. | |
| Handles optional fields by omitting them if None. Pinecone doesn't accept | |
| null values for metadata fields, so parent_id must be omitted entirely | |
| when None (for root nodes). | |
| Args: | |
| record: PineconeRecord model instance | |
| Returns: | |
| Dictionary ready for Pinecone upsert_records | |
| """ | |
| # Use by_alias=True to serialize 'id' as '_id' per model serialization_alias | |
| record_dict = record.model_dump(exclude_none=False, by_alias=True) | |
| # Remove None values for optional fields | |
| optional_fields = { | |
| "asn_identifier", | |
| "statement_notation", | |
| "statement_label", | |
| "normalized_subject", | |
| "publication_status", | |
| "parent_id", # Must be omitted when None (Pinecone doesn't accept null) | |
| "document_id", | |
| "document_valid", | |
| } | |
| for field in optional_fields: | |
| if record_dict.get(field) is None: | |
| record_dict.pop(field, None) | |
| return record_dict | |
| def search_standards( | |
| self, | |
| query_text: str, | |
| top_k: int = 5, | |
| grade: str | None = None, | |
| ) -> list[dict]: | |
| """ | |
| Perform semantic search over standards. | |
| Args: | |
| query_text: Natural language query | |
| top_k: Maximum number of results | |
| grade: Optional grade filter | |
| Returns: | |
| List of result dictionaries with metadata and scores | |
| """ | |
| # Build filter dictionary dynamically | |
| # Always filter to only leaf nodes (actual standards, not parent categories) | |
| filter_parts = [{"is_leaf": {"$eq": True}}] | |
| if grade: | |
| filter_parts.append({"education_levels": {"$in": [grade]}}) | |
| filter_dict = None | |
| if len(filter_parts) == 1: | |
| filter_dict = filter_parts[0] | |
| elif len(filter_parts) == 2: | |
| filter_dict = {"$and": filter_parts} | |
| # Build query dictionary | |
| query_dict: dict[str, Any] = { | |
| "inputs": {"text": query_text}, | |
| "top_k": top_k * 2, # Get more candidates for reranking | |
| } | |
| if filter_dict: | |
| query_dict["filter"] = filter_dict | |
| # Call search with reranking | |
| results = self.index.search( | |
| namespace=self.namespace, | |
| query=query_dict, | |
| rerank={"model": "bge-reranker-v2-m3", "top_n": top_k, "rank_fields": ["content"]}, | |
| ) | |
| # Parse results | |
| hits = results.get("result", {}).get("hits", []) | |
| parsed_results = [] | |
| for hit in hits: | |
| result_dict = { | |
| "_id": hit["_id"], | |
| "score": hit["_score"], | |
| **hit.get("fields", {}), | |
| } | |
| parsed_results.append(result_dict) | |
| return parsed_results | |
| def fetch_standard(self, standard_id: str) -> dict | None: | |
| """ | |
| Fetch a standard by its GUID (_id field only). | |
| This method performs a direct lookup using Pinecone's fetch() API, which only | |
| works with the standard's GUID (_id field). It does NOT search by statement_notation, | |
| asn_identifier, or any other metadata fields. | |
| Args: | |
| standard_id: Standard GUID (_id field) - must be the exact GUID format | |
| (e.g., "EA60C8D165F6481B90BFF782CE193F93") | |
| Returns: | |
| Standard dictionary with metadata, or None if not found | |
| """ | |
| result = self.index.fetch(ids=[standard_id], namespace=self.namespace) | |
| # Extract vectors from FetchResponse | |
| # FetchResponse.vectors is a dict mapping ID to Vector objects | |
| vectors = result.vectors | |
| if not vectors or standard_id not in vectors: | |
| return None | |
| vector = vectors[standard_id] | |
| # Extract metadata from Vector object | |
| # Vector has: id, values (embedding), and metadata (dict with all fields) | |
| metadata = vector.metadata or {} | |
| vector_id = vector.id | |
| # Combine _id with all metadata fields | |
| record_dict = { | |
| "_id": vector_id, | |
| **metadata, | |
| } | |
| return record_dict | |
def is_uploaded(set_dir: Path) -> bool:
    """Return True when the .pinecone_uploaded marker file is present.

    Args:
        set_dir: Path to standard set directory.

    Returns:
        True if marker file exists, False otherwise.
    """
    return (set_dir / ".pinecone_uploaded").exists()
def mark_uploaded(set_dir: Path) -> None:
    """Write the upload marker file containing an ISO 8601 UTC timestamp.

    Args:
        set_dir: Path to standard set directory.
    """
    marker = set_dir / ".pinecone_uploaded"
    # Timezone-aware UTC timestamp, serialized in ISO 8601 form.
    marker.write_text(
        datetime.now(timezone.utc).isoformat(), encoding="utf-8"
    )
    logger.debug(f"Created upload marker: {marker}")
def get_upload_timestamp(set_dir: Path) -> str | None:
    """Read the timestamp from the upload marker file.

    Args:
        set_dir: Path to standard set directory.

    Returns:
        ISO 8601 timestamp string if the marker exists and is readable,
        None otherwise.
    """
    marker_file = set_dir / ".pinecone_uploaded"
    if not marker_file.exists():
        return None
    try:
        return marker_file.read_text(encoding="utf-8").strip()
    # Narrowed from a broad `except Exception`: only I/O and decode failures
    # are expected from read_text; anything else should surface as a bug.
    except (OSError, UnicodeDecodeError) as e:
        logger.warning(f"Failed to read upload marker {marker_file}: {e}")
        return None