# ProSavantRRF / model.py
import os
import pickle

import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer

class SavantRRFEngine:
    """Retrieval-augmented generation engine: recalls context from a FAISS
    memory bank, then generates with a causal language model."""

    def __init__(self):
        self.assets_path = os.path.join(os.path.dirname(__file__), "assets")
        print("🔹 Loading Savant-RRF memory...")
        self.load_memory()
        print("🔹 Loading Savant-RRF model...")
        self.load_model()
        # Load the query encoder once here rather than on every recall_memory() call.
        self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    def load_memory(self):
        index_path = os.path.join(self.assets_path, "memory.index")
        memory_path = os.path.join(self.assets_path, "persistent_memory.pkl")
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"❌ Missing FAISS index: {index_path}")
        if not os.path.exists(memory_path):
            raise FileNotFoundError(f"❌ Missing memory data: {memory_path}")
        self.index = faiss.read_index(index_path)
        with open(memory_path, "rb") as f:
            self.memory = pickle.load(f)
        print(f"✅ Memory bank loaded: {len(self.memory)} entries")
    def load_model(self):
        try:
            # Primary: load from the Hugging Face Hub. HF_TOKEN (if set) grants
            # access to a gated or private repo.
            hf_token = os.getenv("HF_TOKEN")
            self.tokenizer = AutoTokenizer.from_pretrained("antonypamo/ProSavantRRF", token=hf_token)
            self.model = AutoModelForCausalLM.from_pretrained("antonypamo/ProSavantRRF", token=hf_token)
            print("✅ Model loaded from HF Hub.")
        except Exception as e:
            # Fallback: load the copy shipped in the local assets directory.
            print(f"⚠️ HF load failed: {e}\n🔁 Falling back to local model...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.assets_path)
            self.model = AutoModelForCausalLM.from_pretrained(self.assets_path)
            print("✅ Local model loaded.")
        self.pipe = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)
    def recall_memory(self, query, top_k=5):
        # Embed the query with the shared encoder and fetch its nearest neighbours.
        q_emb = self.encoder.encode([query])
        D, I = self.index.search(np.array(q_emb).astype("float32"), top_k)
        # FAISS pads results with -1 when fewer than top_k entries exist; skip those.
        return [self.memory[i][0] for i in I[0] if i != -1]
    def infer(self, prompt):
        retrieved = self.recall_memory(prompt, top_k=3)
        context = "\n".join(retrieved)
        full_prompt = f"Context:\n{context}\n\nUser: {prompt}\nSavant-RRF:"
        result = self.pipe(full_prompt, max_new_tokens=150, do_sample=True, temperature=0.7)
        return result[0]["generated_text"]
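
# Hypothetical asset-builder sketch (an assumption, not part of the original file):
# one way the files consumed by load_memory() could be produced. It assumes memory
# entries are (text, metadata) tuples, matching the self.memory[i][0] access above,
# and that the index is built with the same all-MiniLM-L6-v2 encoder used at query time.
def build_memory_assets(entries, assets_path="assets"):
    """entries: list of (text, metadata) tuples to persist and index."""
    encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    embeddings = np.array(encoder.encode([text for text, _ in entries])).astype("float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])  # exact L2 nearest-neighbour search
    index.add(embeddings)
    os.makedirs(assets_path, exist_ok=True)
    faiss.write_index(index, os.path.join(assets_path, "memory.index"))
    with open(os.path.join(assets_path, "persistent_memory.pkl"), "wb") as f:
        pickle.dump(entries, f)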
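
# Minimal usage sketch, assuming assets/ already holds memory.index and
# persistent_memory.pkl (e.g. written by a builder like the one above).
if __name__ == "__main__":
    engine = SavantRRFEngine()
    print(engine.infer("Summarize what you remember about Savant-RRF."))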