File size: 1,829 Bytes
e6471bf
8f3f3ca
e6471bf
 
 
 
8f3f3ca
e6471bf
 
8f3f3ca
e6471bf
 
 
 
 
 
 
8f3f3ca
 
 
e6471bf
 
 
 
 
 
 
 
8f3f3ca
 
 
 
e6471bf
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
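"""Matrix-BIOS-Memory-0.1 — grounded, citation-faithful recall (RAG).

Downloads a FAISS index and a small document corpus from the Hugging Face Hub,
retrieves the passages nearest to a question, and prompts a seq2seq model to
answer using only those passages and to cite their [id]s. Each answer is
returned together with the ids of the retrieved passages supplied as context.

Requirements:
    pip install torch transformers sentence-transformers faiss-cpu huggingface_hub
"""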
import json
import faiss
import torch
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

REPO = "ruslanmv/Matrix-BIOS-Memory-0.1"


def _load_json(file_path):
    """Read and parse a JSON file as UTF-8, closing the file handle afterwards."""
    # Explicit UTF-8: the default text encoding is locale-dependent, and the
    # corpus may contain non-ASCII text.
    with open(file_path, encoding="utf-8") as fh:
        return json.load(fh)


# Local directory holding the downloaded snapshot of the Hub repository.
path = snapshot_download(REPO)
cfg = _load_json(f"{path}/memory_config.json")   # embedder / generator / top_k
docs = _load_json(f"{path}/docs.json")           # [{"id": ..., "text": ...}]
index = faiss.read_index(f"{path}/index.faiss")

embedder = SentenceTransformer(cfg["embedder"])
gen_tok = AutoTokenizer.from_pretrained(cfg["generator"])
# eval() puts the generator in inference mode (e.g. disables dropout).
gen_model = AutoModelForSeq2SeqLM.from_pretrained(cfg["generator"]).eval()

def _build_prompt(question, hits):
    """Render the grounded-QA prompt from *question* and the retrieved *hits*."""
    context = "\n".join(f"[{d['id']}] {d['text']}" for d in hits)
    return ("Answer the question using ONLY the context, and cite the [id] you used.\n"
            f"Context:\n{context}\n\nQuestion: {question}\nAnswer:")


def answer(question: str, *, max_new_tokens: int = 64):
    """Answer *question* from the retrieved corpus passages.

    The question is embedded, the ``cfg["top_k"]`` nearest passages are fetched
    from the FAISS index, and the generator is prompted to answer using only
    those passages and to cite their ``[id]``.

    Args:
        question: Natural-language question.
        max_new_tokens: Upper bound on the number of generated tokens
            (defaults to 64, the previously hard-coded value).

    Returns:
        ``(answer_text, source_ids)``. ``source_ids`` are the ids of the
        passages actually included in the prompt, in retrieval order.
    """
    qv = embedder.encode([question], normalize_embeddings=True).astype("float32")
    _, idx = index.search(qv, cfg["top_k"])
    # FAISS reports missing results as -1, so keep only valid corpus positions.
    hits = [docs[i] for i in idx[0] if 0 <= i < len(docs)]

    # The question and the "Answer:" cue sit at the END of the prompt, so plain
    # right-side truncation would cut them off first. Drop the lowest-ranked
    # passages instead until the prompt fits the tokenizer's length limit.
    limit = gen_tok.model_max_length
    prompt = _build_prompt(question, hits)
    while hits and len(gen_tok(prompt).input_ids) > limit:
        hits.pop()
        prompt = _build_prompt(question, hits)

    # truncation=True stays as a last resort for an over-long question alone.
    enc = gen_tok(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():
        out = gen_model.generate(
            input_ids=enc["input_ids"],
            attention_mask=enc.get("attention_mask"),
            max_new_tokens=max_new_tokens,
        )
    return gen_tok.decode(out[0], skip_special_tokens=True), [d["id"] for d in hits]

if __name__ == "__main__":
    # Demo run: one question about Matrix OS and one general-knowledge question
    # in Italian. Each prints the generated answer and the ids of the
    # retrieved source passages.
    demo_questions = (
        "What does every effectful action in Matrix OS emit?",
        "Qual e la capitale d'Italia?",
    )
    for question in demo_questions:
        reply, source_ids = answer(question)
        print(f"Q: {question}\nA: {reply}\n   sources: {source_ids}\n")