import os
from typing import Dict, List, Optional

import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
|
|
# Shard identity and topic are read from the environment, with defaults.
SHARD_ID = os.environ.get("SHARD_ID", "shard_0")
TOPIC = os.environ.get("SHARD_TOPIC", "General")

# HTTP application object; the startup hook and the /retrieve route below
# register themselves on it.
app = FastAPI()

# NOTE(review): the first character of this message looks like a mis-decoded
# emoji -- confirm the file's encoding before touching the string.
print(f"๐ [Shard {SHARD_ID}] Initializing... Topic: {TOPIC}")

# Sentence-embedding model, loaded once at import time and shared by
# seed_knowledge() and retrieve().
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# In-memory fact store for this shard: filled by seed_knowledge(), read by
# retrieve().
KNOWLEDGE_BASE = []
|
|
def seed_knowledge():
    """Fill ``KNOWLEDGE_BASE`` with the built-in facts for this shard's topic.

    Facts are selected by ``TOPIC``; a topic without an entry in the table
    falls back to two generic placeholder facts. Each fact is embedded with
    ``embedder`` and stored as a dict with the keys ``id``, ``text``,
    ``vector`` and ``metadata``.

    The knowledge base is replaced in place rather than appended to, so
    calling this function more than once never yields duplicate entries or
    duplicate ids.
    """
    print(f"🌱 Seeding Knowledge for {TOPIC}...")
    topics_data = {
        "Science": [
            "The speed of light is 299,792,458 m/s.",
            "Mitochondria is the powerhouse of the cell.",
            "Water boils at 100 degrees Celsius at sea level.",
        ],
        "History": [
            "The Roman Empire fell in 476 AD.",
            "World War II ended in 1945.",
            "The Great Wall of China was built over centuries.",
        ],
        "Coding": [
            "Python uses indentation for blocks.",
            "React uses a virtual DOM for performance.",
            "Docker containers share the host OS kernel.",
        ],
    }

    facts = topics_data.get(TOPIC, ["Generic fact 1", "Generic fact 2"])

    # One batched call embeds every fact; row i is the vector for facts[i].
    vectors = embedder.encode(facts)

    # Slice assignment mutates the existing list object, so any other
    # reference to KNOWLEDGE_BASE keeps seeing the current contents.
    KNOWLEDGE_BASE[:] = [
        {
            "id": f"{SHARD_ID}_{i}",
            "text": text,
            "vector": vec,
            "metadata": {"centrality": 0.9, "recency": 1.0, "weight": 1.0},
        }
        for i, (text, vec) in enumerate(zip(facts, vectors))
    ]
    print(f"✅ Loaded {len(KNOWLEDGE_BASE)} facts.")
|
|
@app.on_event("startup")
async def startup_event():
    """Load this shard's facts when the application starts up."""
    # All of the work happens in seed_knowledge(); this coroutine only hooks
    # it into the application's startup event.
    seed_knowledge()
|
|
| |
|
|
class RetrievalRequest(BaseModel):
    # Request body for POST /retrieve.

    # Natural-language query; the endpoint embeds it when no usable
    # query_vector is supplied.
    query_text: str
    # Optional pre-computed query embedding. Declared Optional so that an
    # explicit JSON null validates the same way as omitting the field.
    query_vector: Optional[List[float]] = None
|
|
class RetrievalResponse(BaseModel):
    # Response body of POST /retrieve: the single best-scoring fact held by
    # this shard.

    # Identifier of the shard that produced the answer.
    shard_id: str
    # Text of the highest-scoring knowledge-base entry.
    best_text: str
    # Score of that entry: cosine similarity scaled by its metadata factors.
    score: float
|
|
@app.post("/retrieve", response_model=RetrievalResponse)
async def retrieve(req: RetrievalRequest):
    """Return the knowledge-base entry that best matches the query.

    The query embedding is ``req.query_vector`` when a non-empty vector is
    supplied; otherwise ``req.query_text`` is embedded with ``embedder``.
    Each entry's cosine similarity to the query is multiplied by its
    ``centrality``, ``recency`` and ``weight`` metadata, and the entry with
    the highest product is returned (the earliest entry wins ties).

    Raises:
        HTTPException: 503 when the knowledge base holds no entries; 400
            when the query vector is not a flat, finite vector with the same
            dimension as the stored embeddings.
    """
    if not KNOWLEDGE_BASE:
        # Without this guard the norm/indexing below fails and surfaces as an
        # opaque 500.
        raise HTTPException(status_code=503, detail="Knowledge base is empty.")

    if req.query_vector:
        q_vec = np.asarray(req.query_vector, dtype=float)
    else:
        # NOTE(review): encode() is called synchronously inside this
        # coroutine -- presumably it blocks the event loop while the model
        # runs; confirm whether that matters for this service.
        q_vec = embedder.encode(req.query_text)

    db_vecs = np.array([item["vector"] for item in KNOWLEDGE_BASE])

    # Reject malformed client vectors up front instead of letting the dot
    # product raise (wrong size) or propagate NaN/inf into the response.
    if q_vec.ndim != 1 or q_vec.shape[0] != db_vecs.shape[1]:
        raise HTTPException(
            status_code=400,
            detail=(
                f"query_vector must be a flat vector of length "
                f"{db_vecs.shape[1]}."
            ),
        )
    if not np.all(np.isfinite(q_vec)):
        raise HTTPException(
            status_code=400,
            detail="query_vector must contain only finite numbers.",
        )

    # Cosine similarity of the query against every stored vector; the small
    # epsilon keeps a zero-norm vector from dividing by zero.
    norm_q = np.linalg.norm(q_vec)
    norm_db = np.linalg.norm(db_vecs, axis=1)
    sims = np.dot(db_vecs, q_vec) / (norm_q * norm_db + 1e-9)

    # Per-entry multiplier built from the stored metadata.
    priors = np.array(
        [
            item["metadata"]["centrality"]
            * item["metadata"]["recency"]
            * item["metadata"]["weight"]
            for item in KNOWLEDGE_BASE
        ],
        dtype=float,
    )
    scores = sims * priors

    # argmax returns the first maximum and needs no sentinel starting score,
    # so the reported score is always the winning entry's own score.
    best_idx = int(np.argmax(scores))

    return RetrievalResponse(
        shard_id=SHARD_ID,
        best_text=KNOWLEDGE_BASE[best_idx]["text"],
        score=float(scores[best_idx]),
    )
|
|
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, on the port named
    # by $PORT (8001 when the variable is unset).
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "8001")),
    )
|
|