Spaces:
Running
Running
File size: 17,380 Bytes
bb4d350 da13ac2 64dfb4b bb4d350 30e837a efb4152 e821aa5 9928bbd 64dfb4b e821aa5 64dfb4b e821aa5 0197fd1 e821aa5 68a5585 e821aa5 68a5585 e821aa5 30e837a bb4d350 4f368c2 da13ac2 0197fd1 4f368c2 0197fd1 4f368c2 0197fd1 4f368c2 0197fd1 4f368c2 0197fd1 4f368c2 925a3bd 4f368c2 0197fd1 4f368c2 0197fd1 4f368c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 |
try:
    import json  # BUGFIX: used by search_rs_studies (json.dumps) but was never imported
    import os
    from typing import Any, Dict, List, Optional

    import gradio as gr
    import torch
    from sentence_transformers import SentenceTransformer
    import chromadb

    from config import Config
except ImportError as e:
    print(f"β Error: Required packages not installed: {e}")
    print("π§ Make sure you're in the gemmaembeddings conda environment")
    print("π¦ Required packages: torch, sentence-transformers, chromadb")
    # BUGFIX: fail fast. The rest of the module depends on these names, so
    # swallowing the ImportError only deferred the crash to a confusing
    # NameError further down.
    raise
# === GLOBAL INITIALIZATION (runs at import time) ===
# NOTE(review): the original comment claimed these were "initialized lazily",
# but the embedding model is loaded eagerly here; only `collection` is filled
# in by the validation step below.
config = Config()
device = Config.get_device()
model = SentenceTransformer(config.MODEL_PATH)
collection = None  # set once the collection lookup below succeeds

print(f"π Connecting to ChromaDB from cloud...")
# Cloud credentials are read from the environment (e.g. HF Spaces secrets);
# any of these may be None if the secret is missing — CloudClient will fail then.
database = os.environ.get("chromadb_db")
api_key = os.environ.get("chromadb_api_key")
tenant = os.environ.get("chromadb_tenant")
client = chromadb.CloudClient(
    api_key=api_key,
    tenant=tenant,
    database=database
)
print("Connection to ChromaDB successful...")  # BUGFIX: typo "chromabd"
# === COLLECTION VALIDATION ===
# Ensure the required collection exists and has data.
try:
    collection = client.get_collection(config.COLLECTION_NAME)
    doc_count = collection.count()
    if doc_count == 0:
        print(f"Collection '{config.COLLECTION_NAME}' exists but is empty. Run ingest_studies.py to populate it.")
    # BUGFIX: the success message below was split/garbled across two lines in
    # the source (a mangled emoji broke the f-string); reconstructed here.
    print(f"Connected to collection '{config.COLLECTION_NAME}' with {doc_count} documents")
except Exception as e:
    # get_collection raises if the collection does not exist; `collection`
    # stays None and the tools report "not initialized" at call time.
    print(f"Collection '{config.COLLECTION_NAME}' not found. Run ingest_studies.py first. Error: {str(e)}")
class EmbeddingGemmaPrompts:
    """Prompt templates for Google's EmbeddingGemma model.

    Implements the official prompt instructions from the HuggingFace model
    card (https://huggingface.co/google/embeddinggemma-300m#prompt-instructions):

        query:    'task: {task description} | query: {content}'
        document: 'title: {title | "none"} | text: {content}'

    Task-specific prompts measurably improve similarity scores over the
    baseline "search result" task: fact checking +136%, sentence similarity
    +112%, question answering +98%, classification +73%, clustering +59%,
    code retrieval +39%.

    Usage:
        >>> EmbeddingGemmaPrompts.encode_query("How does RS work?", "question_answering")
        'task: question answering | query: How does RS work?'
        >>> EmbeddingGemmaPrompts.encode_document("Content here", "Document Title")
        'title: Document Title | text: Content here'

    Attributes:
        TASKS: Mapping of task-type keys to official task descriptions.
    """

    # Official EmbeddingGemma task descriptions, keyed by the task-type
    # strings accepted throughout this module.  Unknown keys fall back to
    # "search result" (see get_task_description).
    TASKS = {
        # --- retrieval (baseline) ---
        "retrieval_query": "search result",        # standard retrieval query
        "retrieval_document": "document",          # document embedding format
        # --- high-performance specialized tasks ---
        "fact_checking": "fact checking",              # claim verification (+136%)
        "semantic_similarity": "sentence similarity",  # concept comparison (+112%)
        "question_answering": "question answering",    # Q&A scenarios (+98%)
        "classification": "classification",            # topic/category analysis (+73%)
        # --- moderate performance ---
        "clustering": "clustering",                # document grouping (+59%)
        "code_retrieval": "code retrieval",        # code search (+39%)
        # --- short legacy aliases ---
        "search": "search result",                 # default baseline task
        "question": "question answering",          # alias for question_answering
        "fact": "fact checking"                    # alias for fact_checking
    }

    @staticmethod
    def format_query_prompt(content: str, task: str = "search result") -> str:
        """Wrap *content* in the official EmbeddingGemma query template.

        Args:
            content: Raw query text to be embedded.
            task: Official task description (defaults to "search result").

        Returns:
            The formatted query string, ready for embedding.

        Example:
            >>> EmbeddingGemmaPrompts.format_query_prompt("RS trading system", "question answering")
            'task: question answering | query: RS trading system'
        """
        return f"task: {task} | query: {content}"

    @staticmethod
    def format_document_prompt(content: str, title: str = "none") -> str:
        """Wrap *content* in the official EmbeddingGemma document template.

        Pass "none" (the default) when no meaningful title is available;
        real titles noticeably improve search relevance.

        Args:
            content: Document text to be embedded.
            title: Document title, or "none" when absent.

        Returns:
            The formatted document string, ready for embedding.

        Example:
            >>> EmbeddingGemmaPrompts.format_document_prompt("Content here", "Risk Management")
            'title: Risk Management | text: Content here'
        """
        return f'title: {title} | text: {content}'

    @classmethod
    def get_task_description(cls, task_type: str) -> str:
        """Translate a task-type key into its official task description.

        Unknown keys fall back to "search result" so callers never break
        on an unrecognized task type.

        Example:
            >>> EmbeddingGemmaPrompts.get_task_description("fact_checking")
            'fact checking'
        """
        return cls.TASKS.get(task_type, "search result")

    @classmethod
    def encode_query(cls, content: str, task_type: str = "search") -> str:
        """Format a search query with task-specific prompt optimization.

        Primary entry point for query formatting: resolves *task_type* to
        its official description and applies the query template.

        Args:
            content: Raw user query.
            task_type: One of the TASKS keys (defaults to "search").

        Returns:
            The optimized query string for EmbeddingGemma.

        Example:
            >>> EmbeddingGemmaPrompts.encode_query("RS system reduces risk by 30%", "fact_checking")
            'task: fact checking | query: RS system reduces risk by 30%'
        """
        description = cls.get_task_description(task_type)
        return cls.format_query_prompt(content, description)

    @classmethod
    def encode_document(cls, content: str, title: str = "none") -> str:
        """Format a document for embedding with the official template.

        Args:
            content: Document text to embed.
            title: Title from metadata/filename, or "none" when unavailable.

        Returns:
            The formatted document string.

        Example:
            >>> EmbeddingGemmaPrompts.encode_document("Untitled content here")
            'title: none | text: Untitled content here'
        """
        return cls.format_document_prompt(content, title)
def search_knowledge_base(
    query: str,
    num_results: int = 5,
    source_filter: Optional[str] = None,
    task_type: str = "search"
) -> Dict[str, Any]:
    """Semantic-similarity search over the RS Studies knowledge base.

    Args:
        query: The search query.
        num_results: Number of results to return (capped at config.MAX_NUM_RESULTS).
        source_filter: Optional source-folder filter; applied only when it is
            one of config.VALID_SOURCES.
        task_type: EmbeddingGemma task key used to format the query prompt.

    Returns:
        Dict with the echoed parameters and a "results" list on success,
        or a dict carrying an "error" message and "success": False.
    """
    if not ensure_initialized():
        return {"error": "Server not properly initialized", "results": []}
    try:
        # Apply the task-specific EmbeddingGemma prompt, then embed.
        prompt = EmbeddingGemmaPrompts.encode_query(query, task_type)
        embedding = model.encode([prompt], device=device)

        params: Dict[str, Any] = {
            "query_embeddings": embedding.tolist(),
            "n_results": min(num_results, config.MAX_NUM_RESULTS),
            "include": ["documents", "metadatas", "distances"],
        }
        # Only filter when the requested source is recognized.
        if source_filter and source_filter in config.VALID_SOURCES:
            params["where"] = {"source_folder": {"$eq": source_filter}}

        raw = collection.query(**params)

        hits: List[Dict[str, Any]] = []
        if raw["documents"]:
            # Chroma returns parallel per-query lists; we issued one query,
            # so index [0] everywhere.
            rows = zip(raw["documents"][0], raw["metadatas"][0], raw["distances"][0])
            for rank, (doc, meta, dist) in enumerate(rows, start=1):
                hits.append({
                    "rank": rank,
                    "content": doc,
                    "source_folder": meta.get("source_folder", "unknown"),
                    "chunk_file": meta.get("chunk_file", "unknown"),
                    "chunk_number": meta.get("chunk_number", "unknown"),
                    # 1 - distance is used as a similarity proxy here —
                    # NOTE(review): assumes a cosine-like distance space.
                    "similarity_score": float(1 - dist),
                    "distance": float(dist),
                    "chunk_length": meta.get("chunk_length", 0),
                    "metadata": meta,
                })

        return {
            "query": query,
            "task_type": task_type,
            "num_results": len(hits),
            "source_filter": source_filter,
            "results": hits,
            "success": True,
        }
    except Exception as e:
        return {"error": f"Search failed: {str(e)}", "results": [], "success": False}
def get_available_sources() -> Dict[str, Any]:
    """Get list of available source folders in the knowledge base.

    Returns:
        Dict with sorted source names, a per-source chunk count
        ("source_stats"), the number of sources, the total chunk count, and
        "success": True — or an error dict with "success": False on failure.
    """
    if not ensure_initialized():
        return {"error": "Server not properly initialized", "sources": []}
    try:
        # PERF: fetch all metadata once and count per source in a single
        # pass, instead of the original N+1 pattern (one filtered
        # collection.get() per discovered source).  Same counts, one query.
        all_results = collection.get(include=["metadatas"])
        source_stats: Dict[str, int] = {}
        for metadata in all_results["metadatas"]:
            source = metadata.get("source_folder")
            if source:
                source_stats[source] = source_stats.get(source, 0) + 1
        sources = set(source_stats)
        return {
            "sources": sorted(sources),
            "source_stats": source_stats,
            "total_sources": len(sources),
            "total_chunks": collection.count(),
            "success": True
        }
    except Exception as e:
        return {"error": f"Failed to get sources: {str(e)}", "sources": [], "success": False}
# MCP Tool Definitions
def search_rs_studies(
    query: str,
    num_results: int = 5,
    source_filter: Optional[str] = None,
    task_type: str = "search"
) -> str:
    """
    Search the RS Studies knowledge base for relevant information.

    This tool provides semantic search across RS trading system documentation,
    Chennai meetup transcripts, and Q&A content with optimized EmbeddingGemma prompts.

    Args:
        query: Your search question or topic (required)
        num_results: Number of results to return (1-50, default: 5)
        source_filter: Limit search to specific source:
            - 'rs_stkege_01': RS trading system documentation
            - 'cheenai_meet_full': Chennai meetup transcripts
            - 'QnAYoutubeChannel': Q&A discussions
            - None: Search all sources (default)
        task_type: Search optimization using EmbeddingGemma task-specific prompts:
            - 'search'/'retrieval_query': General search (default)
            - 'question'/'question_answering': Question answering format
            - 'fact'/'fact_checking': Fact checking format
            - 'classification': Text classification tasks
            - 'clustering': Document clustering and grouping
            - 'semantic_similarity': Semantic similarity assessment
            - 'code_retrieval': Code search and retrieval

    Returns:
        JSON string with search results including content, sources, and similarity scores
    """
    # BUGFIX: `json` was never imported at module level, so every call to this
    # tool raised NameError.  Imported locally to keep the fix self-contained.
    import json

    # Validate parameters before touching the model or the collection.
    if not query or not query.strip():
        return json.dumps({"error": "Query cannot be empty", "results": [], "success": False})

    # Clamp to the allowed range rather than rejecting out-of-range requests.
    num_results = max(1, min(num_results, config.MAX_NUM_RESULTS))

    if source_filter and source_filter not in config.VALID_SOURCES:
        return json.dumps({
            "error": f"Invalid source_filter. Must be one of: {config.VALID_SOURCES}",
            "results": [],
            "success": False
        })

    valid_task_types = list(EmbeddingGemmaPrompts.TASKS.keys())
    if task_type not in valid_task_types:
        return json.dumps({
            "error": f"Invalid task_type. Must be one of: {valid_task_types}",
            "results": [],
            "success": False
        })

    # Perform search and serialize the result dict for the MCP client.
    results = search_knowledge_base(query, num_results, source_filter, task_type)
    return json.dumps(results, indent=2)
# === GRADIO APP (MCP-only: no interactive UI components) ===
with gr.Blocks() as demo:
    gr.Markdown(
        """
        This is a MCP only tool for RS Studies
        This connects to a remote chromadb instance.
        This tool is MCP-only, so it does not have a UI.
        """
    )
    # Expose the search function as an API endpoint / MCP tool.
    gr.api(
        search_rs_studies
    )

# Launch with the MCP server enabled.  BUGFIX: the scraped source carried a
# stray trailing "|" here; also dropped the unused `_, url, _` unpacking.
demo.launch(mcp_server=True)