Spaces:
Sleeping
Sleeping
File size: 4,231 Bytes
dd41762 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 | """
HTTP API wrapper for Neural Memory.
Provides REST endpoints for the comparison demo.
Run alongside or instead of the MCP server.
"""
import logging
import os
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from .config import MemoryConfig
from .memory.neural_memory import NeuralMemory
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize
app = FastAPI(title="Neural Memory API", version="1.0.0")
# Any origin may call the API. Credentials are deliberately NOT enabled: the
# endpoints in this module read no cookies or auth headers, and FastAPI's CORS
# documentation states allow_credentials=True cannot be combined with ["*"]
# for origins/methods/headers (Starlette would otherwise reflect whatever
# Origin the caller sends and allow credentialed cross-site requests).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configuration, overridable via the MEMORY_DIM and LEARNING_RATE environment
# variables (defaults: 512 and 0.02). A malformed value raises ValueError at
# import time.
config = MemoryConfig(
    dim=int(os.environ.get("MEMORY_DIM", "512")),
    learning_rate=float(os.environ.get("LEARNING_RATE", "0.02")),
)
# Single shared instance; the endpoints look up this module-level name on
# every call, so /reset can rebind it.
memory = NeuralMemory(config)
logger.info(
    "Neural Memory HTTP API initialized: dim=%s, lr=%s",
    config.dim,
    config.learning_rate,
)
# Request/Response models
class ObserveRequest(BaseModel):
    # JSON body for POST /observe.
    # Text handed to memory.observe.
    content: str
    # Optional per-request learning rate, forwarded unchanged to
    # memory.observe (None when omitted). NOTE(review): presumably None means
    # "use the configured rate" — confirm in NeuralMemory.observe.
    learning_rate: Optional[float] = None
class SurpriseRequest(BaseModel):
    # JSON body for POST /surprise, which scores content without learning.
    # Text handed to memory.surprise.
    content: str
class ObserveResponse(BaseModel):
    # Response model for POST /observe.
    # Copied from the "surprise" key of the dict memory.observe returns.
    surprise: float
    # Copied from the "weight_delta" key of that same dict.
    weight_delta: float
    # True when memory.get_weight_hash() differs before vs. after observing.
    learned: bool
    # Weight hash captured after the observation.
    weight_hash: str
class SurpriseResponse(BaseModel):
    # Response model for POST /surprise.
    # Score returned by memory.surprise for the submitted content.
    surprise: float
    # One of "learn" (score > 0.7), "skip" (score < 0.3) or "moderate".
    recommendation: str
class StatsResponse(BaseModel):
    # Response model for GET /stats. Every field except weight_hash is read
    # from the same-named key of the dict memory.get_stats() returns.
    total_observations: int
    weight_parameters: int
    avg_surprise: float
    learning_rate: float
    dimension: int
    # Taken from memory.get_weight_hash(), not from get_stats().
    weight_hash: str
# Endpoints
@app.get("/health")
async def health() -> Dict[str, Any]:
    """Liveness probe that also reports the size of the memory.

    Returns an object with ``status`` (always ``"healthy"``), ``memory_dim``
    (the configured dimension) and ``parameters`` (the total element count
    summed over ``memory.parameters()``).
    """
    total = 0
    for param in memory.parameters():
        total += param.numel()
    return dict(status="healthy", memory_dim=config.dim, parameters=total)
@app.post("/observe", response_model=ObserveResponse)
async def observe(request: ObserveRequest) -> ObserveResponse:
    """Observe content and trigger learning.

    ``surprise`` and ``weight_delta`` are copied from the dict returned by
    ``memory.observe``; ``learned`` is True when the weight hash taken before
    the call differs from the one taken after it.

    Raises:
        HTTPException: status 500 with the exception text as ``detail`` when
            any step (hashing, observing, building the response) fails.
    """
    # NOTE(review): this coroutine never awaits, so the three memory calls run
    # back-to-back on the event loop without another request interleaving —
    # that is what keeps the before/after hash comparison meaningful. The
    # flip side is that a slow observe blocks the whole server meanwhile.
    try:
        hash_before = memory.get_weight_hash()
        result = memory.observe(request.content, learning_rate=request.learning_rate)
        hash_after = memory.get_weight_hash()
        return ObserveResponse(
            surprise=result["surprise"],
            weight_delta=result["weight_delta"],
            learned=hash_before != hash_after,
            weight_hash=hash_after,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logging str(e) alone loses.
        logger.exception("Observe error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Thresholds for the /surprise recommendation. Both bounds are exclusive.
_LEARN_THRESHOLD = 0.7
_SKIP_THRESHOLD = 0.3


def _recommend(score: float) -> str:
    """Map a surprise score to "learn", "skip" or "moderate".

    A score strictly above 0.7 is "learn", strictly below 0.3 is "skip", and
    anything else (including exactly 0.7 or 0.3, and NaN) is "moderate".
    """
    if score > _LEARN_THRESHOLD:
        return "learn"
    if score < _SKIP_THRESHOLD:
        return "skip"
    return "moderate"


@app.post("/surprise", response_model=SurpriseResponse)
async def surprise(request: SurpriseRequest) -> SurpriseResponse:
    """Check surprise without learning.

    Scores ``request.content`` with ``memory.surprise`` and attaches a
    recommendation derived from that score.

    Raises:
        HTTPException: status 500 with the exception text as ``detail`` when
            scoring or building the response fails.
    """
    try:
        score = memory.surprise(request.content)
        return SurpriseResponse(surprise=score, recommendation=_recommend(score))
    except Exception as e:
        # logger.exception keeps the traceback, which logging str(e) alone loses.
        logger.exception("Surprise error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/stats", response_model=StatsResponse)
async def stats() -> StatsResponse:
    """Get memory statistics.

    Every field except ``weight_hash`` is read from the dict returned by
    ``memory.get_stats()``; ``weight_hash`` comes from
    ``memory.get_weight_hash()``.

    Raises:
        HTTPException: status 500 with the exception text as ``detail`` when
            either call fails or an expected stats key is missing.
    """
    try:
        mem_stats = memory.get_stats()
        return StatsResponse(
            total_observations=mem_stats["total_observations"],
            weight_parameters=mem_stats["weight_parameters"],
            avg_surprise=mem_stats["avg_surprise"],
            learning_rate=mem_stats["learning_rate"],
            dimension=mem_stats["dimension"],
            weight_hash=memory.get_weight_hash(),
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logging str(e) alone loses.
        logger.exception("Stats error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/reset")
async def reset() -> Dict[str, str]:
    """Replace the shared memory with a fresh NeuralMemory built from ``config``.

    The module-level ``memory`` name is rebound, so every later request sees
    the new instance. Returns ``status`` and the new instance's weight hash.
    NOTE(review): this restores the "initial state" only if NeuralMemory
    construction is deterministic — confirm in NeuralMemory.__init__.
    """
    global memory
    replacement = NeuralMemory(config)
    memory = replacement
    return dict(status="reset", weight_hash=replacement.get_weight_hash())
def main(host: Optional[str] = None, port: Optional[int] = None) -> None:
    """Run the HTTP server with uvicorn.

    Args:
        host: Interface to bind. Defaults to ``"0.0.0.0"`` (all interfaces),
            matching the previous hard-coded behavior.
        port: TCP port to listen on. Defaults to the ``PORT`` environment
            variable, or 8765 when that is unset.

    Raises:
        ValueError: if ``port`` is omitted and ``PORT`` is not an integer.
    """
    import uvicorn

    if host is None:
        host = "0.0.0.0"
    if port is None:
        port = int(os.environ.get("PORT", "8765"))
    uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    # Direct-launch entry point. The module uses relative imports
    # (from .config ...), so it must be run as part of its package,
    # e.g. with `python -m`, not as a standalone script.
    main()
|