File size: 3,991 Bytes
23da55a
 
 
 
 
 
8b7c793
23da55a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf16341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import json
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, InferenceClient

# Credentials and target dataset repo, read from the environment at import
# time. HF_TOKEN may be None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
REPO_ID = os.getenv("HF_DATASET_ID", "Brettapps/brettapps-aussie-mcp-databank")

# Shared inference client (used by get_embeddings), routed through the
# "hf-inference" provider.
client = InferenceClient(api_key=HF_TOKEN, provider="hf-inference")

def get_embeddings(text):
    """Embed *text* with the facebook/bart-base model via the shared client.

    Returns whatever ``client.feature_extraction`` returns, or None if the
    request raised (the error is printed rather than propagated).
    """
    try:
        vectors = client.feature_extraction(text, model="facebook/bart-base")
    except Exception as err:
        print(f"Embedding error: {err}")
        return None
    return vectors

def save_to_databank(filename, content, folder="knowledge"):
    """Write *content* locally, then upload it to the Hugging Face dataset repo.

    The file is first written to ``<folder>/<filename>`` relative to the
    current working directory. Dicts and lists are serialized as indented
    JSON; anything else is written as text. It is then uploaded to the same
    ``<folder>/<filename>`` path in the ``REPO_ID`` dataset repository.

    Returns:
        True if the upload succeeded, False if it raised (the error is
        printed). Errors while writing the local file propagate.
    """
    api = HfApi(token=HF_TOKEN)
    path_in_repo = f"{folder}/{filename}"
    local_path = os.path.join(folder, filename)

    # Create the file's own parent directory: this also covers filenames that
    # contain subdirectories and an empty ``folder`` (where os.makedirs("")
    # would raise).
    os.makedirs(os.path.dirname(local_path) or ".", exist_ok=True)

    # Explicit UTF-8 so the uploaded bytes don't depend on the host locale.
    with open(local_path, "w", encoding="utf-8") as f:
        if isinstance(content, (dict, list)):
            json.dump(content, f, indent=2)
        else:
            f.write(content)

    try:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=path_in_repo,
            repo_id=REPO_ID,
            repo_type="dataset",
        )
    except Exception as e:
        print(f"Upload error: {e}")
        return False
    return True

def load_from_databank(filename, folder="knowledge"):
    """Download ``<folder>/<filename>`` from the dataset repo and return it.

    Files ending in ``.json`` are parsed with ``json.load``; anything else is
    returned as a string. Returns None if the download, read or parse raised
    (the error is printed).
    """
    try:
        local_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=f"{folder}/{filename}",
            repo_type="dataset",
            token=HF_TOKEN,
        )
        # Decode explicitly as UTF-8 (matching how save_to_databank-style
        # uploads are written) instead of the locale-dependent default.
        with open(local_path, "r", encoding="utf-8") as f:
            if filename.endswith(".json"):
                return json.load(f)
            return f.read()
    except Exception as e:
        print(f"Download error: {e}")
        return None

class KnowledgeManager:
    """Semantic index over the local Markdown knowledge/persona files.

    Each ``.md`` file directly inside ``knowledge_dir`` is embedded once
    (its first ``PREVIEW_CHARS`` characters) with ``get_embeddings``.
    Queries are matched against that index by cosine similarity.
    """

    # Persona returned when no better match can be determined.
    DEFAULT_PERSONA = "router_instructions.md"
    # Leading characters of each file that get embedded (saves time/resources).
    PREVIEW_CHARS = 500

    def __init__(self, knowledge_dir="knowledge"):
        self.knowledge_dir = knowledge_dir
        self.index = {}  # filename -> embedding as returned by get_embeddings
        self.initialized = False

    def initialize_index(self):
        """Build the semantic index for all local ``.md`` knowledge files.

        Does nothing when ``knowledge_dir`` does not exist, and leaves
        ``initialized`` False so a later call retries. Files whose embedding
        request fails are left out of the index.
        """
        if not os.path.exists(self.knowledge_dir):
            return

        # Sorted so the index order (and therefore tie-breaking in
        # find_relevant_persona) doesn't depend on directory listing order.
        for filename in sorted(os.listdir(self.knowledge_dir)):
            if not filename.endswith(".md"):
                continue
            path = os.path.join(self.knowledge_dir, filename)
            if not os.path.isfile(path):
                # e.g. a directory named "*.md"; open() would raise on it.
                continue
            # Read only the preview. Undecodable bytes are replaced, since
            # this text is only used for embedding. The file is closed before
            # the network call is made.
            with open(path, "r", encoding="utf-8", errors="replace") as f:
                preview = f.read(self.PREVIEW_CHARS)
            embedding = get_embeddings(preview)
            if embedding is not None:
                self.index[filename] = embedding
        self.initialized = True
        print(f"Knowledge index initialized with {len(self.index)} files.")

    @staticmethod
    def _as_vector(embedding):
        """Return *embedding* as a 1-D float vector.

        ``feature_extraction`` may return one vector per token, e.g. shape
        ``(1, seq_len, hidden)``, for models that are not sentence-embedding
        models. Such outputs are mean-pooled over every axis except the
        last, so query and file embeddings of different lengths are
        comparable.
        """
        import numpy as np

        vec = np.asarray(embedding, dtype=float)
        if vec.ndim <= 1:
            return vec.reshape(-1)
        return vec.reshape(-1, vec.shape[-1]).mean(axis=0)

    def find_relevant_persona(self, query):
        """Return the indexed filename whose embedding is most similar to *query*.

        Similarity is cosine similarity. Embeddings are not assumed to be
        normalized. Falls back to ``DEFAULT_PERSONA`` when the query cannot
        be embedded, or when no indexed file yields a usable score.
        """
        import numpy as np

        if not self.initialized:
            self.initialize_index()

        query_embedding = get_embeddings(query)
        if query_embedding is None:
            return self.DEFAULT_PERSONA

        q_vec = self._as_vector(query_embedding)
        q_norm = np.linalg.norm(q_vec)
        if not q_norm > 0:  # zero or NaN norm: similarity is undefined
            return self.DEFAULT_PERSONA

        best_file = self.DEFAULT_PERSONA
        best_score = float("-inf")

        for filename, embedding in self.index.items():
            f_vec = self._as_vector(embedding)
            # Skip vectors that can't be compared: a dimension mismatch, or a
            # zero/NaN norm (which would yield a NaN score).
            if f_vec.shape != q_vec.shape:
                continue
            f_norm = np.linalg.norm(f_vec)
            if not f_norm > 0:
                continue
            score = float(np.dot(q_vec, f_vec) / (q_norm * f_norm))
            if score > best_score:
                best_score = score
                best_file = filename

        return best_file