import os
from typing import Dict, List, Optional

import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
app = FastAPI()
# --- State ---
# Shard identity and topic are injected via environment variables
# (presumably set per-container by an orchestrator — defaults allow local runs).
SHARD_ID = os.getenv("SHARD_ID", "shard_0")
TOPIC = os.getenv("SHARD_TOPIC", "General")
print(f"🐝 [Shard {SHARD_ID}] Initializing... Topic: {TOPIC}")
# Load Embedding Model (Shared or Local)
# In production, shards might just do vector math if embeddings are pre-computed,
# but for standalone robustness, we load the model.
# NOTE(review): this downloads/loads the model at import time, which blocks startup.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
# Mock Knowledge Base (In prod, load from Disk/DB)
# Each entry is a dict: {"id", "text", "vector", "metadata"} — filled by seed_knowledge().
KNOWLEDGE_BASE = []
def seed_knowledge():
    """Populate the in-memory KNOWLEDGE_BASE with facts for this shard's TOPIC.

    Looks up TOPIC in a small hard-coded corpus and falls back to generic
    placeholder facts for unknown topics. Embeddings are computed in a single
    batched ``embedder.encode`` call instead of one model call per fact.

    Side effects: appends entries to the module-level KNOWLEDGE_BASE list.
    """
    print(f"🌱 Seeding Knowledge for {TOPIC}...")
    topics_data = {
        "Science": [
            "The speed of light is 299,792,458 m/s.",
            "Mitochondria is the powerhouse of the cell.",
            "Water boils at 100 degrees Celsius at sea level."
        ],
        "History": [
            "The Roman Empire fell in 476 AD.",
            "World War II ended in 1945.",
            "The Great Wall of China was built over centuries."
        ],
        "Coding": [
            "Python uses indentation for blocks.",
            "React uses a virtual DOM for performance.",
            "Docker containers share the host OS kernel."
        ]
    }
    facts = topics_data.get(TOPIC, ["Generic fact 1", "Generic fact 2"])
    # Batch-encode: one forward pass for all facts rather than N separate calls.
    vectors = embedder.encode(facts)
    for i, (text, vec) in enumerate(zip(facts, vectors)):
        KNOWLEDGE_BASE.append({
            "id": f"{SHARD_ID}_{i}",
            "text": text,
            "vector": vec,
            # Static priors consumed by /retrieve's P = S * C * R * W scoring.
            "metadata": {"centrality": 0.9, "recency": 1.0, "weight": 1.0}
        })
    print(f"✅ Loaded {len(KNOWLEDGE_BASE)} facts.")
@app.on_event("startup")
async def startup_event():
    # Seed the knowledge base once the server's event loop is running.
    # NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of a
    # lifespan handler; left as-is here since switching would require changing
    # how `app` is constructed above.
    seed_knowledge()
# --- API ---
class RetrievalRequest(BaseModel):
    """Payload for POST /retrieve: query text plus an optional precomputed embedding."""
    query_text: str
    # The router may supply a precomputed embedding so the shard can skip its
    # own encode step. Declared Optional explicitly: under Pydantic v2 a bare
    # `List[float] = None` default fails validation at model-definition time.
    query_vector: Optional[List[float]] = None
class RetrievalResponse(BaseModel):
    """Single best hit returned by POST /retrieve."""
    # Identifier of the answering shard (from the SHARD_ID env var).
    shard_id: str
    # Text of the highest-scoring fact in this shard's knowledge base.
    best_text: str
    # Final priority score: cosine similarity * centrality * recency * weight.
    score: float
@app.post("/retrieve", response_model=RetrievalResponse)
async def retrieve(req: RetrievalRequest):
    """Return this shard's single best-matching fact for a query.

    Scoring: P = cosine_similarity * centrality * recency * weight, taken
    from each entry's metadata. Raises HTTP 503 if the knowledge base is
    empty (e.g. the request arrived before startup seeding completed).
    """
    if not KNOWLEDGE_BASE:
        # Guard: np.dot on an empty (0,)-shaped matrix would raise a confusing
        # ValueError; surface a clear "not ready" signal instead.
        raise HTTPException(status_code=503, detail="Shard knowledge base is empty")
    # 1. Query vector: trust a router-supplied embedding when present and
    #    non-empty, otherwise embed the text locally.
    if req.query_vector is not None and len(req.query_vector) > 0:
        q_vec = np.array(req.query_vector)
    else:
        q_vec = embedder.encode(req.query_text)
    # 2. Cosine similarity against all stored vectors in one shot.
    db_vecs = np.array([item["vector"] for item in KNOWLEDGE_BASE])
    norm_q = np.linalg.norm(q_vec)
    norm_db = np.linalg.norm(db_vecs, axis=1)
    # 1e-9 guards against divide-by-zero for degenerate (all-zero) vectors.
    sims = np.dot(db_vecs, q_vec) / (norm_q * norm_db + 1e-9)
    # 3. Metadata-weighted priority (P = S * C * R * W), vectorized argmax
    #    instead of a Python loop over every entry.
    meta_weights = np.array([
        item["metadata"]["centrality"]
        * item["metadata"]["recency"]
        * item["metadata"]["weight"]
        for item in KNOWLEDGE_BASE
    ])
    p_scores = sims * meta_weights
    best_idx = int(np.argmax(p_scores))
    return RetrievalResponse(
        shard_id=SHARD_ID,
        best_text=KNOWLEDGE_BASE[best_idx]["text"],
        score=float(p_scores[best_idx]),
    )
if __name__ == "__main__":
    # Standalone entry point: serve this shard on PORT (defaults to 8001).
    listen_port = int(os.environ.get("PORT", 8001))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)