|
|
import os

# Hugging Face access token, read once from the environment. None simply
# means anonymous access, which is fine for public model repos.
hf_token = os.getenv("HF_TOKEN")

# NOTE(review): this file originally assigned self.tokenizer / self.model at
# module level right here, which raises NameError at import time (`self` is
# undefined outside a class) and duplicated SavantRRFEngine.load_model().
# The broken duplicate lines were removed; model loading happens in the class.
|
|
import pickle |
|
|
import faiss |
|
|
import numpy as np |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SavantRRFEngine:
    """Retrieval-augmented generation engine.

    Pairs a FAISS similarity index over a pickled memory bank with a causal
    language model pulled from the Hugging Face Hub (falling back to a local
    copy bundled in ``assets/``). ``infer`` retrieves the most similar stored
    memories and prepends them as context to the user prompt.
    """

    def __init__(self):
        # Bundled data files (FAISS index, pickled memory bank, local model
        # fallback) live next to this module in an ``assets`` directory.
        self.assets_path = os.path.join(os.path.dirname(__file__), "assets")
        # Sentence encoder is created lazily on first recall_memory() call
        # so construction stays cheap and the encoder is built only once.
        self._encoder = None
        print("Loading Savant-RRF memory...")
        self.load_memory()
        print("Loading Savant-RRF model...")
        self.load_model()

    def load_memory(self):
        """Load the FAISS index and the pickled memory bank from assets.

        Raises:
            FileNotFoundError: if either asset file is missing.
        """
        index_path = os.path.join(self.assets_path, "memory.index")
        memory_path = os.path.join(self.assets_path, "persistent_memory.pkl")

        if not os.path.exists(index_path):
            raise FileNotFoundError(f"Missing FAISS index: {index_path}")
        if not os.path.exists(memory_path):
            raise FileNotFoundError(f"Missing memory data: {memory_path}")

        self.index = faiss.read_index(index_path)
        # SECURITY: pickle.load can execute arbitrary code during
        # deserialization — only ship trusted memory files in assets/.
        with open(memory_path, "rb") as f:
            self.memory = pickle.load(f)
        print(f"Memory bank loaded: {len(self.memory)} entries")

    def load_model(self):
        """Load tokenizer and model from the HF Hub, falling back to assets/.

        Passes the HF_TOKEN environment variable (if set) so private or
        gated repositories are reachable; anonymous access still works when
        the variable is unset.
        """
        hf_token = os.getenv("HF_TOKEN")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                "antonypamo/ProSavantRRF", token=hf_token
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                "antonypamo/ProSavantRRF", token=hf_token
            )
            print("Model loaded from HF Hub.")
        except Exception as e:
            # Offline or auth failure: fall back to the bundled local copy.
            print(f"HF load failed: {e}\nFalling back to local model...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.assets_path)
            self.model = AutoModelForCausalLM.from_pretrained(self.assets_path)
            print("Local model loaded.")

        self.pipe = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def recall_memory(self, query, top_k=5):
        """Return the text of the ``top_k`` memories most similar to ``query``.

        Args:
            query: free-text query to embed and search with.
            top_k: maximum number of memory entries to return.
        """
        # Build the sentence encoder once and reuse it — the original
        # reloaded the full SentenceTransformer model on every call.
        if getattr(self, "_encoder", None) is None:
            self._encoder = SentenceTransformer(
                "sentence-transformers/all-MiniLM-L6-v2"
            )
        q_emb = self._encoder.encode([query])
        # FAISS requires float32 queries; I holds the neighbour row indices.
        D, I = self.index.search(np.array(q_emb).astype("float32"), top_k)
        # FAISS pads with -1 when fewer than top_k neighbours exist; skip
        # those so we don't silently return memory[-1] (the last entry).
        return [self.memory[i][0] for i in I[0] if i != -1]

    def infer(self, prompt):
        """Generate a reply to ``prompt`` grounded on retrieved memories."""
        retrieved = self.recall_memory(prompt, top_k=3)
        context = "\n".join(retrieved)
        full_prompt = f"Context:\n{context}\n\nUser: {prompt}\nSavant-RRF:"
        result = self.pipe(
            full_prompt, max_new_tokens=150, do_sample=True, temperature=0.7
        )
        return result[0]["generated_text"]
|
|
|
|
|
|