File size: 17,380 Bytes
bb4d350
 
da13ac2
64dfb4b
bb4d350
 
 
 
 
 
 
 
 
 
30e837a
 
efb4152
 
 
 
e821aa5
9928bbd
 
64dfb4b
e821aa5
64dfb4b
 
 
e821aa5
0197fd1
e821aa5
 
 
 
 
 
 
 
68a5585
e821aa5
 
 
 
68a5585
e821aa5
30e837a
bb4d350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f368c2
da13ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0197fd1
 
 
 
 
 
 
4f368c2
0197fd1
 
 
 
 
4f368c2
0197fd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f368c2
0197fd1
4f368c2
0197fd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f368c2
925a3bd
4f368c2
 
 
0197fd1
 
4f368c2
 
 
 
0197fd1
4f368c2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424

import json
import os
from typing import Any, Dict, List, Optional

try:
    import gradio as gr
    import torch
    from sentence_transformers import SentenceTransformer
    import chromadb
    from config import Config
except ImportError as e:
    print(f"❌ Error: Required packages not installed: {e}")
    print("πŸ”§ Make sure you're in the gemmaembeddings conda environment")
    print("πŸ“¦ Required packages: torch, sentence-transformers, chromadb")

# Module-level state shared by the search helpers below. The config, device and
# embedding model are created eagerly at import time; `collection` starts as None
# and is filled in by `_load_collection()` at start-up and, if that fails, again
# on demand from `ensure_initialized()`.
config = Config()
device = Config.get_device()
model = SentenceTransformer(config.MODEL_PATH)
collection = None

print(f"πŸ”„ Connecting to ChromaDB from cloud...")
# Chroma Cloud credentials are read from environment variables.
database = os.environ.get("chromadb_db")
api_key = os.environ.get("chromadb_api_key")
tenant = os.environ.get("chromadb_tenant")
client = chromadb.CloudClient(
    api_key=api_key,
    tenant=tenant,
    database=database
)
print("Connection to chromadb successful...")


# === COLLECTION VALIDATION ===
def _load_collection():
    """Fetch the configured collection from ChromaDB and report its size.

    Returns:
        The collection object, or None if it could not be fetched or counted.
    """
    try:
        loaded = client.get_collection(config.COLLECTION_NAME)
        doc_count = loaded.count()
    except Exception as e:
        # Best-effort: the server still starts so the MCP tool can report a clear
        # "not initialized" error instead of the whole app failing at import time.
        print(f"Collection '{config.COLLECTION_NAME}' not found. Run ingest_studies.py first. Error: {str(e)}")
        return None

    if doc_count == 0:
        print(f"Collection '{config.COLLECTION_NAME}' exists but is empty. Run ingest_studies.py to populate it.")

    print(f"βœ… Connected to collection '{config.COLLECTION_NAME}' with {doc_count} documents")
    return loaded


def ensure_initialized() -> bool:
    """Return True when the embedding model and ChromaDB collection are usable.

    Called by the search helpers before they touch ``collection``. If the
    collection could not be loaded at start-up, one more load attempt is made
    here, so the server can recover (e.g. after ingest_studies.py has been run)
    without a restart.
    """
    global collection
    if model is None:
        return False
    if collection is None:
        collection = _load_collection()
    return collection is not None


# Ensure the required collection exists and has data.
collection = _load_collection()


class EmbeddingGemmaPrompts:
    """
    Prompt templates for Google's EmbeddingGemma model.

    Implements the prompt instructions from the HuggingFace model card, so that
    queries and documents are embedded with task-specific prefixes.

    Reference: https://huggingface.co/google/embeddinggemma-300m#prompt-instructions

    Prompt formats:
    - Query:    'task: {task description} | query: {content}'
    - Document: 'title: {title | "none"} | text: {content}'

    Reported impact relative to the "search result" baseline (from the authors'
    own testing; NOTE(review): not reproduced in this file):
    - task: fact checking        -> +136% similarity
    - task: sentence similarity  -> +112% similarity
    - task: question answering   -> +98% similarity
    - task: classification       -> +73% similarity

    Usage:
        >>> EmbeddingGemmaPrompts.encode_query("How does RS work?", "question_answering")
        'task: question answering | query: How does RS work?'
        >>> EmbeddingGemmaPrompts.encode_document("Content here", "Document Title")
        'title: Document Title | text: Content here'

    Attributes:
        TASKS (Dict[str, str]): Maps task-type keys to the task description that
            is inserted into query prompts.
    """

    @staticmethod
    def format_query_prompt(content: str, task: str = "search result") -> str:
        """
        Format a query with the EmbeddingGemma query template.

        Produces 'task: {task} | query: {content}'.

        Args:
            content (str): The raw query text to be embedded.
            task (str): Task description (not a TASKS key). Defaults to "search result".

        Returns:
            str: Formatted query string ready for embedding.

        Example:
            >>> EmbeddingGemmaPrompts.format_query_prompt("RS trading system", "question answering")
            'task: question answering | query: RS trading system'
        """
        return f"task: {task} | query: {content}"

    @staticmethod
    def format_document_prompt(content: str, title: Optional[str] = "none") -> str:
        """
        Format a document with the EmbeddingGemma document template.

        Produces 'title: {title | "none"} | text: {content}'. Per the template, a
        missing title is written as the literal "none". A ``None``, empty or
        whitespace-only title is therefore normalized to "none" rather than
        leaving an empty title slot.

        Args:
            content (str): The document text to be embedded.
            title (Optional[str]): Document title. Defaults to "none".

        Returns:
            str: Formatted document string ready for embedding.

        Example:
            >>> EmbeddingGemmaPrompts.format_document_prompt("Content here", "Risk Management")
            'title: Risk Management | text: Content here'

            >>> EmbeddingGemmaPrompts.format_document_prompt("Content without title")
            'title: none | text: Content without title'
        """
        if title is None or not str(title).strip():
            title = "none"
        return f'title: {title} | text: {content}'

    # Task-type keys -> EmbeddingGemma task descriptions used in query prompts.
    # Performance figures in the comments are the authors' reported results.
    TASKS = {
        # === RETRIEVAL TASKS ===
        # General-purpose retrieval (baseline performance)
        "retrieval_query": "search result",     # Standard retrieval query format
        # NOTE(review): "document" is not a query task in the class-level format
        # list; documents are normally encoded via encode_document instead.
        "retrieval_document": "document",

        # === HIGH-PERFORMANCE SPECIALIZED TASKS ===
        # Verifying claims and finding evidence (+136% reported)
        "fact_checking": "fact checking",

        # Concept comparison and relationship analysis (+112% reported)
        "semantic_similarity": "sentence similarity",

        # Q&A scenarios with contextual responses (+98% reported)
        "question_answering": "question answering",

        # Content categorization and topic analysis (+73% reported)
        "classification": "classification",

        # === MODERATE PERFORMANCE TASKS ===
        # Document grouping and clustering (+59% reported)
        "clustering": "clustering",

        # Finding code examples and implementations (+39% reported)
        "code_retrieval": "code retrieval",

        # === LEGACY COMPATIBILITY ===
        # Shorter aliases for backward compatibility
        "search": "search result",        # Default baseline task
        "question": "question answering", # Alias for question_answering
        "fact": "fact checking"           # Alias for fact_checking
    }

    @classmethod
    def get_task_description(cls, task_type: str) -> str:
        """
        Look up the EmbeddingGemma task description for a task-type key.

        Unknown keys fall back to "search result".

        Args:
            task_type (str): A key of TASKS (e.g. "question_answering", "fact_checking").

        Returns:
            str: The task description (e.g. "question answering", "fact checking").

        Example:
            >>> EmbeddingGemmaPrompts.get_task_description("fact_checking")
            'fact checking'

            >>> EmbeddingGemmaPrompts.get_task_description("unknown_task")
            'search result'
        """
        return cls.TASKS.get(task_type, "search result")

    @classmethod
    def encode_query(cls, content: str, task_type: str = "search") -> str:
        """
        Format a query with the task prompt selected by ``task_type``.

        This is the main entry point for formatting search queries. It resolves
        ``task_type`` through TASKS and then applies the query template.

        Args:
            content (str): The raw query text from the user.
            task_type (str): A key of TASKS. Defaults to "search". Unknown keys
                fall back to "search result".

        Returns:
            str: Query string formatted for EmbeddingGemma.

        Example:
            >>> EmbeddingGemmaPrompts.encode_query("How does risk management work?", "question_answering")
            'task: question answering | query: How does risk management work?'

            >>> EmbeddingGemmaPrompts.encode_query("RS system reduces risk by 30%", "fact_checking")
            'task: fact checking | query: RS system reduces risk by 30%'
        """
        task_desc = cls.get_task_description(task_type)
        return cls.format_query_prompt(content, task_desc)

    @classmethod
    def encode_document(cls, content: str, title: Optional[str] = "none") -> str:
        """
        Format a document with the EmbeddingGemma document template.

        Args:
            content (str): The document text to embed.
            title (Optional[str]): Document title, e.g. from metadata, a filename or
                a header. A ``None``, empty or blank title is written as "none".

        Returns:
            str: Formatted document string ready for embedding.

        Best Practices:
            - Extract titles from filenames, headers, or metadata when possible.
            - Keep titles concise and descriptive (< 100 characters).

        Example:
            >>> EmbeddingGemmaPrompts.encode_document("Trading strategy content...", "Momentum Strategy Guide")
            'title: Momentum Strategy Guide | text: Trading strategy content...'

            >>> EmbeddingGemmaPrompts.encode_document("Untitled content here")
            'title: none | text: Untitled content here'
        """
        return cls.format_document_prompt(content, title)



def search_knowledge_base(
    query: str, 
    num_results: int = 5, 
    source_filter: Optional[str] = None,
    task_type: str = "search"
) -> Dict[str, Any]:
    """
    Search the RS Studies knowledge base using semantic similarity.

    The query is wrapped in an EmbeddingGemma task prompt, embedded with the
    module-level ``model`` on ``device``, and sent to the ChromaDB ``collection``.

    Args:
        query: The search query.
        num_results: Number of results to return; clamped to
            ``[1, config.MAX_NUM_RESULTS]``.
        source_filter: Optional source folder filter. Ignored unless it is one
            of ``config.VALID_SOURCES``.
        task_type: EmbeddingGemma task key (see ``EmbeddingGemmaPrompts.TASKS``);
            unknown keys fall back to "search result".

    Returns:
        On success: a dict with ``query``, ``task_type``, ``num_results``,
        ``source_filter``, ``results`` (ranked hits) and ``success=True``.
        On failure: a dict with ``error``, an empty ``results`` list and
        ``success=False``.
    """
    if not ensure_initialized():
        return {"error": "Server not properly initialized", "results": [], "success": False}

    try:
        # Create the query embedding with task-specific formatting.
        query_formatted = EmbeddingGemmaPrompts.encode_query(query, task_type)
        query_embedding = model.encode([query_formatted], device=device)

        search_params = {
            "query_embeddings": query_embedding.tolist(),
            # Clamp both ends: callers other than search_rs_studies may pass <= 0.
            "n_results": max(1, min(num_results, config.MAX_NUM_RESULTS)),
            "include": ["documents", "metadatas", "distances"]
        }

        if source_filter and source_filter in config.VALID_SOURCES:
            search_params["where"] = {"source_folder": {"$eq": source_filter}}

        results = collection.query(**search_params)

        formatted_results = []
        if results["documents"]:
            # One inner list per query embedding; exactly one embedding was sent.
            hits = zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
            for rank, (content, metadata, distance) in enumerate(hits, start=1):
                # Chunks stored without metadata come back as None.
                metadata = metadata or {}
                formatted_results.append({
                    "rank": rank,
                    "content": content,
                    "source_folder": metadata.get("source_folder", "unknown"),
                    "chunk_file": metadata.get("chunk_file", "unknown"),
                    "chunk_number": metadata.get("chunk_number", "unknown"),
                    # NOTE(review): 1 - distance is a cosine similarity only if the
                    # collection uses cosine space; confirm in ingest_studies.py.
                    "similarity_score": float(1 - distance),
                    "distance": float(distance),
                    "chunk_length": metadata.get("chunk_length", 0),
                    "metadata": metadata
                })

        return {
            "query": query,
            "task_type": task_type,
            "num_results": len(formatted_results),
            "source_filter": source_filter,
            "results": formatted_results,
            "success": True
        }

    except Exception as e:
        return {"error": f"Search failed: {str(e)}", "results": [], "success": False}

def get_available_sources() -> Dict[str, Any]:
    """
    List the source folders present in the knowledge base, with chunk counts.

    Returns:
        On success: a dict with ``sources`` (sorted names), ``source_stats``
        (name -> chunk count), ``total_sources``, ``total_chunks`` and
        ``success=True``. On failure: a dict with ``error``, an empty
        ``sources`` list and ``success=False``.
    """
    if not ensure_initialized():
        return {"error": "Server not properly initialized", "sources": [], "success": False}

    try:
        all_results = collection.get(include=["metadatas"])

        # Count chunks per source from this single fetch, instead of issuing
        # one extra filtered query per source.
        source_stats: Dict[str, int] = {}
        for metadata in all_results["metadatas"] or []:
            source = (metadata or {}).get("source_folder")
            if source:
                source_stats[source] = source_stats.get(source, 0) + 1

        return {
            "sources": sorted(source_stats),
            "source_stats": source_stats,
            "total_sources": len(source_stats),
            "total_chunks": collection.count(),
            "success": True
        }

    except Exception as e:
        return {"error": f"Failed to get sources: {str(e)}", "sources": [], "success": False}

# MCP Tool Definitions
def search_rs_studies(
    query: str, 
    num_results: int = 5, 
    source_filter: Optional[str] = None,
    task_type: str = "search"
) -> str:
    """
    Search the RS Studies knowledge base for relevant information.

    This tool provides semantic search across RS trading system documentation,
    Chennai meetup transcripts, and Q&A content with optimized EmbeddingGemma prompts.

    Args:
        query: Your search question or topic (required)
        num_results: Number of results to return (at least 1, capped at the server's configured maximum; default: 5)
        source_filter: Limit search to specific source:
                      - 'rs_stkege_01': RS trading system documentation
                      - 'cheenai_meet_full': Chennai meetup transcripts  
                      - 'QnAYoutubeChannel': Q&A discussions
                      - None: Search all sources (default)
        task_type: Search optimization using EmbeddingGemma task-specific prompts:
                  - 'search'/'retrieval_query': General search (default)
                  - 'question'/'question_answering': Question answering format
                  - 'fact'/'fact_checking': Fact checking format
                  - 'classification': Text classification tasks
                  - 'clustering': Document clustering and grouping
                  - 'semantic_similarity': Semantic similarity assessment
                  - 'code_retrieval': Code search and retrieval

    Returns:
        JSON string with search results including content, sources, and similarity scores
    """
    # Validate parameters. MCP clients send JSON, so types are not guaranteed.
    if not isinstance(query, str) or not query.strip():
        return json.dumps({"error": "Query cannot be empty", "results": [], "success": False})

    try:
        num_results = int(num_results)
    except (TypeError, ValueError, OverflowError):
        return json.dumps({
            "error": "num_results must be an integer",
            "results": [],
            "success": False
        })
    num_results = max(1, min(num_results, config.MAX_NUM_RESULTS))

    if source_filter and source_filter not in config.VALID_SOURCES:
        return json.dumps({
            "error": f"Invalid source_filter. Must be one of: {config.VALID_SOURCES}",
            "results": [],
            "success": False
        })

    valid_task_types = list(EmbeddingGemmaPrompts.TASKS.keys())
    if task_type not in valid_task_types:
        return json.dumps({
            "error": f"Invalid task_type. Must be one of: {valid_task_types}",
            "results": [],
            "success": False
        })

    # Perform search
    results = search_knowledge_base(query, num_results, source_filter, task_type)
    return json.dumps(results, indent=2)


# Gradio app with no interactive UI: search_rs_studies is registered as an API
# endpoint, which Gradio also publishes as an MCP tool when launched with
# mcp_server=True.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        This is a MCP only tool for RS Studies
        This connects to a remote chromadb instance.
        This tool is MCP-only, so it does not have a UI.
        """
    )
    gr.api(
        search_rs_studies
    )

# Only start the server when run as a script, so importing this module
# (e.g. from tests or `gradio app.py` reload mode) does not block on launch().
if __name__ == "__main__":
    demo.launch(mcp_server=True)