| | import logging |
| | import os |
| | from dataclasses import dataclass |
| | from enum import Enum |
| | from typing import Any, Dict, List, Optional |
| |
|
| | import mteb |
| | from sqlitedict import SqliteDict |
| |
|
| | from pylate import indexes, models, retrieve |
| |
|
| | |
# Configure root logging once at import time: INFO level, with a
# timestamped "<time> - <logger name> - <level> - <message>" record format.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
| |
|
| |
|
class IndexType(Enum):
    """Supported index types.

    NOTE(review): neither member is referenced in this chunk; presumably
    PREBUILT marks an index shipped ready-made and LOCAL one built on this
    machine — confirm against the callers that consume ``IndexConfig.type``.
    """

    PREBUILT = "prebuilt"
    LOCAL = "local"
| |
|
| |
|
@dataclass
class IndexConfig:
    """Configuration for a search index."""

    # Identifier for the index.
    name: str
    # Whether the index is prebuilt or built locally (see IndexType).
    type: IndexType
    # Location of the index — presumably a filesystem path; not consumed
    # in this chunk, verify against callers.
    path: str
    # Optional free-text description of the index.
    description: Optional[str] = None
| |
|
| |
|
class MCPyLate:
    """Main server class that manages PyLate indexes and search operations."""

    def __init__(
        self,
        override: bool = False,
        *,
        dataset_name: str = "leetcode",
        model_name: str = "lightonai/Reason-ModernColBERT",
    ):
        """Load the ColBERT model and open (or build) the PLAID index.

        Args:
            override: Force a rebuild of the index even if it exists on disk.
            dataset_name: BRIGHT corpus split to index (was hard-coded to
                "leetcode"; kept as the default for backward compatibility).
            model_name: HF identifier of the ColBERT model (default unchanged).
        """
        self.logger = logging.getLogger(__name__)
        # e.g. "leetcode_Reason-ModernColBERT" — computed once instead of
        # rebuilding the same f-string in three places.
        index_name = f"{dataset_name}_{model_name.split('/')[-1]}"
        # Rebuild when explicitly requested OR when the index directory is
        # missing on disk (first run).
        override = override or not os.path.exists(f"indexes/{index_name}")

        self.model = models.ColBERT(model_name_or_path=model_name)
        self.index = indexes.PLAID(override=override, index_name=index_name)
        # Persistent docid -> document-text mapping stored alongside the index.
        self.id_to_doc = SqliteDict(
            f"./indexes/{index_name}/id_to_doc.sqlite",
            outer_stack=False,
        )

        self.retriever = retrieve.ColBERT(index=self.index)
        if override:
            self._build_index(dataset_name)
        self.logger.info("Created PyLate MCP Server")

    def _build_index(self, dataset_name: str) -> None:
        """Populate the docid store and the PLAID index from the BRIGHT corpus."""
        task = mteb.get_tasks(tasks=["BrightRetrieval"])[0]
        task.load_data()
        corpus = task.corpus[dataset_name]["standard"]
        for doc_id, doc in corpus.items():
            self.id_to_doc[doc_id] = doc
        self.id_to_doc.commit()
        documents_embeddings = self.model.encode(
            sentences=list(corpus.values()),
            batch_size=100,
            is_query=False,
            show_progress_bar=True,
        )
        self.index.add_documents(
            documents_ids=list(corpus.keys()),
            documents_embeddings=documents_embeddings,
        )

    def get_document(
        self,
        docid: str,
    ) -> Optional[Dict[str, Any]]:
        """Return ``{"docid", "text"}`` for *docid*, or None when unknown.

        The original raised ``KeyError`` for a missing id, contradicting the
        declared ``Optional`` return type; it now returns None instead.
        """
        try:
            return {"docid": docid, "text": self.id_to_doc[docid]}
        except KeyError:
            return None

    def search(self, query: str, k: int = 10) -> List[Dict[str, Any]]:
        """Perform multi-vector search and return the top-*k* hits.

        Each hit is ``{"docid", "score", "text"}``, scores rounded to 5
        decimal places.

        Raises:
            RuntimeError: if encoding or retrieval fails (original cause
                chained via ``from``).
        """
        try:
            query_embeddings = self.model.encode(
                sentences=[query],
                is_query=True,
                show_progress_bar=True,
                batch_size=32,
            )
            # BUG FIX: honor the caller's k — it was hard-coded to 20,
            # silently ignoring the parameter.
            scores = self.retriever.retrieve(queries_embeddings=query_embeddings, k=k)
            return [
                {
                    "docid": hit["id"],
                    "score": round(hit["score"], 5),
                    "text": self.id_to_doc[hit["id"]],
                }
                for hit in scores[0]
            ]
        except Exception as e:
            # logger.exception records the traceback; chain the cause so
            # callers can inspect the underlying failure.
            self.logger.exception("Search failed")
            raise RuntimeError(f"Search operation failed: {e}") from e
| |
|