Text Generation
English
smriti-memory-ai
smriti-ai
memory
agent-memory
long-term-memory
external-memory
training-free
frozen-model
inference-time-augmentation
retrieval-augmented-generation
rag
semantic-search
knowledge-graph
identity-continuity
small-language-model
small-language-models
ai-agent
gemma
gemma-4
qwen
qwen2.5
llama
llama-3.2
phi-3
Deploy Smriti AI Hugging Face handler
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- README.md +187 -0
- config.json +11 -0
- examples/request_delete.json +6 -0
- examples/request_distractor.json +8 -0
- examples/request_memory_inject.json +11 -0
- examples/request_recall.json +11 -0
- handler.py +647 -0
- requirements.txt +19 -0
- smriti_endpoint_config.yaml +35 -0
- smriti_vendor/mempalace/__init__.py +3 -0
- smriti_vendor/mempalace/__pycache__/__init__.cpython-310.pyc +0 -0
- smriti_vendor/mempalace/__pycache__/agent.cpython-310.pyc +0 -0
- smriti_vendor/mempalace/__pycache__/api.cpython-310.pyc +0 -0
- smriti_vendor/mempalace/__pycache__/cli.cpython-310.pyc +0 -0
- smriti_vendor/mempalace/__pycache__/core.cpython-310.pyc +0 -0
- smriti_vendor/mempalace/__pycache__/gifp.cpython-310.pyc +0 -0
- smriti_vendor/mempalace/__pycache__/identity_fingerprint.cpython-310.pyc +0 -0
- smriti_vendor/mempalace/__pycache__/knowledge_graph.cpython-310.pyc +0 -0
- smriti_vendor/mempalace/__pycache__/macp.cpython-310.pyc +0 -0
- smriti_vendor/mempalace/__pycache__/mem_palace.cpython-310.pyc +0 -0
- smriti_vendor/mempalace/__pycache__/semantic_memory.cpython-310.pyc +0 -0
- smriti_vendor/mempalace/agent.py +3 -0
- smriti_vendor/mempalace/api.py +3 -0
- smriti_vendor/mempalace/cli.py +3 -0
- smriti_vendor/mempalace/core.py +3 -0
- smriti_vendor/mempalace/gifp.py +3 -0
- smriti_vendor/mempalace/identity_fingerprint.py +3 -0
- smriti_vendor/mempalace/knowledge_graph.py +3 -0
- smriti_vendor/mempalace/macp.py +3 -0
- smriti_vendor/mempalace/mem_palace.py +3 -0
- smriti_vendor/mempalace/semantic_memory.py +3 -0
- smriti_vendor/smriti/__init__.py +115 -0
- smriti_vendor/smriti/__main__.py +7 -0
- smriti_vendor/smriti/__pycache__/__init__.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/__main__.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/agent.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/api.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/backends.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/cli.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/config.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/core.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/gifp.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/identity_fingerprint.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/knowledge_graph.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/macp.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/mem_palace.cpython-310.pyc +0 -0
- smriti_vendor/smriti/__pycache__/semantic_memory.cpython-310.pyc +0 -0
- smriti_vendor/smriti/agent.py +262 -0
- smriti_vendor/smriti/api.py +538 -0
- smriti_vendor/smriti/backends.py +494 -0
README.md
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
library_name: smriti-ai
|
| 6 |
+
tags:
|
| 7 |
+
- ai-agent
|
| 8 |
+
- memory
|
| 9 |
+
- small-language-models
|
| 10 |
+
- inference-time-augmentation
|
| 11 |
+
- semantic-search
|
| 12 |
+
- knowledge-graph
|
| 13 |
+
- identity-continuity
|
| 14 |
+
- rag
|
| 15 |
+
pipeline_tag: text-generation
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
# Smriti AI
|
| 19 |
+
|
| 20 |
+
## What this is
|
| 21 |
+
|
| 22 |
+
Smriti AI is a memory-augmented inference layer for small language models. It adds external memory, semantic retrieval, knowledge-graph recall, identity continuity, and privacy-ready memory deletion without changing base model weights.
|
| 23 |
+
|
| 24 |
+
This repository layout is intended for a Hugging Face model-style deployment with a custom `handler.py`. The handler loads a base causal language model or calls a remote model endpoint, wraps it with Smriti AI memory, and returns model responses plus retrieved memories.
|
| 25 |
+
|
| 26 |
+
## What this is not
|
| 27 |
+
|
| 28 |
+
Smriti AI is not a newly trained foundation model. It is not a fine-tuned model unless a separate fine-tuned checkpoint is explicitly included. It is an inference-time wrapper around a base language model.
|
| 29 |
+
|
| 30 |
+
Do not interpret this repository as a standalone model checkpoint. The base model is configured through `BASE_MODEL_ID` or `HF_ENDPOINT_URL`.
|
| 31 |
+
|
| 32 |
+
## Research Lineage
|
| 33 |
+
|
| 34 |
+
Smriti AI follows four principles:
|
| 35 |
+
|
| 36 |
+
- **External memory**: conversational facts live outside model weights in a persistent, inspectable store.
|
| 37 |
+
- **Training-free recall**: relevant facts are retrieved and injected at inference time without fine-tuning the base model.
|
| 38 |
+
- **Identity continuity**: persona evidence is tracked as an embedding fingerprint so outputs can be checked for drift.
|
| 39 |
+
- **Small-model augmentation**: small causal language models can become more useful when paired with explicit memory and retrieval.
|
| 40 |
+
|
| 41 |
+
Historical GodelAI-Lite results were measured on an earlier system. Current Smriti AI results are measured separately and should not be conflated with historical results.
|
| 42 |
+
|
| 43 |
+
## Architecture
|
| 44 |
+
|
| 45 |
+
```text
|
| 46 |
+
User request
|
| 47 |
+
-> Smriti AI handler
|
| 48 |
+
-> memory retrieval
|
| 49 |
+
-> graph retrieval
|
| 50 |
+
-> identity context
|
| 51 |
+
-> base model inference
|
| 52 |
+
-> response
|
| 53 |
+
-> memory write/update
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
The handler supports JSON, SQLite, Redis, and Postgres memory backends. For production, use Redis/Postgres or another external durable store. Do not store private user memory in the Hugging Face model repository.
|
| 57 |
+
|
| 58 |
+
## Supported base models
|
| 59 |
+
|
| 60 |
+
Smriti AI is model-agnostic for Hugging Face causal language models.
|
| 61 |
+
|
| 62 |
+
Supported families depend on the installed `transformers` version and endpoint hardware:
|
| 63 |
+
|
| 64 |
+
- Gemma-style causal LMs when available, including the current benchmark path `google/gemma-4-E2B-it`.
|
| 65 |
+
- Llama/Phi/Mistral/Qwen-style causal LMs if supported by the runtime environment.
|
| 66 |
+
- Tiny CPU-safe local smoke-test models such as `sshleifer/tiny-gpt2` for handler validation only.
|
| 67 |
+
|
| 68 |
+
Tiny models are useful for endpoint plumbing tests. They are not public Smriti AI quality benchmarks.
|
| 69 |
+
|
| 70 |
+
## Evaluation
|
| 71 |
+
|
| 72 |
+
Current local Gemma 4-only benchmark artifacts in the main Smriti AI repository report:
|
| 73 |
+
|
| 74 |
+
| Evaluation | Baseline Recall | Smriti AI Recall | Notes |
|
| 75 |
+
|---|---:|---:|---|
|
| 76 |
+
| Gemma-style three-fact protocol | 0/3 | 3/3 | Smriti AI recalls all injected facts after distractors. |
|
| 77 |
+
| Five-mode comparison | 0/3 | 3/3 | TF-IDF, Semantic, Semantic+Graph, and Semantic+Graph+Identity all recall 3/3 in the checked-in run. |
|
| 78 |
+
| Original broader protocol rerun | 0/3 | 3/3 | Overall average improves from 0.524 to 0.832 (`+58.9%`) in the current local Gemma 4 CPU rerun. |
|
| 79 |
+
|
| 80 |
+
Historical GodelAI-Lite results were measured on an earlier system. Current Smriti AI results are measured separately and should not be conflated with historical results.
|
| 81 |
+
|
| 82 |
+
## Privacy
|
| 83 |
+
|
| 84 |
+
Smriti AI stores user memory. Treat it as user data.
|
| 85 |
+
|
| 86 |
+
- Memory can be encrypted by setting `SMRITI_ENCRYPTION_KEY`.
|
| 87 |
+
- `delete_memory` is supported by the handler.
|
| 88 |
+
- Production deployments should use external memory storage such as Redis/Postgres.
|
| 89 |
+
- Do not store private user memory in the Hugging Face model repository.
|
| 90 |
+
- Public/demo deployments should not receive real PII.
|
| 91 |
+
|
| 92 |
+
## Limitations
|
| 93 |
+
|
| 94 |
+
- Retrieval quality depends on the quality and specificity of stored memory.
|
| 95 |
+
- Public/demo deployments should not receive real PII.
|
| 96 |
+
- Durable memory requires external backend or persistent endpoint storage.
|
| 97 |
+
- Latency depends on the base model, backend, retrieval mode, and endpoint hardware.
|
| 98 |
+
- A tiny CPU demo model validates handler plumbing but will not produce Gemma-quality answers.
|
| 99 |
+
- If no `BASE_MODEL_ID` or `HF_ENDPOINT_URL` is configured, the handler falls back to memory-only responses.
|
| 100 |
+
|
| 101 |
+
## Environment variables
|
| 102 |
+
|
| 103 |
+
| Variable | Purpose |
|
| 104 |
+
|---|---|
|
| 105 |
+
| `BASE_MODEL_ID` | Hugging Face model ID to load inside the endpoint. |
|
| 106 |
+
| `HF_ENDPOINT_URL` | Optional remote model endpoint URL. If set, the handler calls this URL instead of loading a local base model. |
|
| 107 |
+
| `HF_TOKEN` | Token for gated/private base models or protected remote endpoints. |
|
| 108 |
+
| `SMRITI_MEMORY_BACKEND` | `json`, `sqlite`, `redis`, or `postgres`. |
|
| 109 |
+
| `SMRITI_MEMORY_PATH` | JSON user-memory directory or SQLite file path. |
|
| 110 |
+
| `REDIS_URL` | External Redis URL. Takes precedence when present. |
|
| 111 |
+
| `POSTGRES_DSN` | External Postgres DSN. Takes precedence when present and Redis is not configured. |
|
| 112 |
+
| `SMRITI_ENCRYPTION_KEY` | Memory encryption key. Do not commit it. |
|
| 113 |
+
| `SMRITI_RETRIEVAL_MODE` | `tfidf`, `semantic`, `semantic_graph`, or `semantic_graph_identity`. |
|
| 114 |
+
| `SMRITI_PUBLIC_DEMO` | `true` or `false`. Use `true` only for non-PII demos. |
|
| 115 |
+
| `SMRITI_MAX_MEMORY_ENTRIES` | Maximum retained entries per user/topic. |
|
| 116 |
+
|
| 117 |
+
## How to call the endpoint
|
| 118 |
+
|
| 119 |
+
### Chat / fact injection
|
| 120 |
+
|
| 121 |
+
```json
|
| 122 |
+
{
|
| 123 |
+
"inputs": {
|
| 124 |
+
"operation": "chat",
|
| 125 |
+
"user_id": "customer-123",
|
| 126 |
+
"message": "My name is Alex and I am a marine biologist.",
|
| 127 |
+
"retrieval_mode": "semantic_graph_identity"
|
| 128 |
+
},
|
| 129 |
+
"parameters": {
|
| 130 |
+
"max_new_tokens": 256,
|
| 131 |
+
"temperature": 0.7,
|
| 132 |
+
"top_p": 0.9,
|
| 133 |
+
"return_memories": true
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
### Recall
|
| 139 |
+
|
| 140 |
+
```json
|
| 141 |
+
{
|
| 142 |
+
"inputs": {
|
| 143 |
+
"operation": "chat",
|
| 144 |
+
"user_id": "customer-123",
|
| 145 |
+
"message": "What do you remember about me?",
|
| 146 |
+
"retrieval_mode": "semantic_graph_identity"
|
| 147 |
+
},
|
| 148 |
+
"parameters": {
|
| 149 |
+
"return_memories": true
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
### Delete memory
|
| 155 |
+
|
| 156 |
+
```json
|
| 157 |
+
{
|
| 158 |
+
"inputs": {
|
| 159 |
+
"operation": "delete_memory",
|
| 160 |
+
"user_id": "customer-123"
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
### Health
|
| 166 |
+
|
| 167 |
+
```json
|
| 168 |
+
{
|
| 169 |
+
"inputs": {
|
| 170 |
+
"operation": "health"
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
## Local test
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
pip install -r requirements.txt
|
| 179 |
+
BASE_MODEL_ID=sshleifer/tiny-gpt2 \
|
| 180 |
+
SMRITI_MEMORY_BACKEND=json \
|
| 181 |
+
SMRITI_MEMORY_PATH=/tmp/smriti_hf_test.json \
|
| 182 |
+
python test_handler_local.py
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
## Custom-container deployment
|
| 186 |
+
|
| 187 |
+
If the standard Hugging Face handler is insufficient for your model size, CUDA libraries, Redis client policy, or enterprise network requirements, deploy the same files in a custom container. Use the main Smriti AI repository Dockerfiles as the starting point, install this handler, and expose a compatible HTTP API through Hugging Face Inference Endpoints custom container support.
|
config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"project": "Smriti AI",
|
| 3 |
+
"base_model": "REPLACE_WITH_BASE_MODEL_ID",
|
| 4 |
+
"retrieval_mode": "semantic_graph_identity",
|
| 5 |
+
"memory_backend": "json",
|
| 6 |
+
"public_demo": false,
|
| 7 |
+
"max_memory_entries": 1000,
|
| 8 |
+
"enable_identity": true,
|
| 9 |
+
"enable_graph": true,
|
| 10 |
+
"enable_encryption": true
|
| 11 |
+
}
|
examples/request_delete.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"inputs": {
|
| 3 |
+
"operation": "delete_memory",
|
| 4 |
+
"user_id": "demo-user"
|
| 5 |
+
}
|
| 6 |
+
}
|
examples/request_distractor.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"inputs": {
|
| 3 |
+
"operation": "chat",
|
| 4 |
+
"user_id": "demo-user",
|
| 5 |
+
"message": "What is the capital of France?",
|
| 6 |
+
"retrieval_mode": "semantic_graph_identity"
|
| 7 |
+
}
|
| 8 |
+
}
|
examples/request_memory_inject.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"inputs": {
|
| 3 |
+
"operation": "chat",
|
| 4 |
+
"user_id": "demo-user",
|
| 5 |
+
"message": "My name is Alex and I am a marine biologist based in Hawaii.",
|
| 6 |
+
"retrieval_mode": "semantic_graph_identity"
|
| 7 |
+
},
|
| 8 |
+
"parameters": {
|
| 9 |
+
"return_memories": true
|
| 10 |
+
}
|
| 11 |
+
}
|
examples/request_recall.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"inputs": {
|
| 3 |
+
"operation": "chat",
|
| 4 |
+
"user_id": "demo-user",
|
| 5 |
+
"message": "What do you remember about me?",
|
| 6 |
+
"retrieval_mode": "semantic_graph_identity"
|
| 7 |
+
},
|
| 8 |
+
"parameters": {
|
| 9 |
+
"return_memories": true
|
| 10 |
+
}
|
| 11 |
+
}
|
handler.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face custom inference handler for Smriti AI.
|
| 2 |
+
|
| 3 |
+
This file is intentionally deployment glue. Core memory, retrieval, graph, and
|
| 4 |
+
identity behavior comes from the installed `smriti` package.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import re
|
| 13 |
+
import sys
|
| 14 |
+
import time
|
| 15 |
+
import urllib.error
|
| 16 |
+
import urllib.request
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from threading import RLock
|
| 19 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
VENDOR_SRC = Path(__file__).resolve().parent / "smriti_vendor"
|
| 22 |
+
if VENDOR_SRC.exists() and str(VENDOR_SRC) not in sys.path:
|
| 23 |
+
sys.path.insert(0, str(VENDOR_SRC))
|
| 24 |
+
|
| 25 |
+
from smriti import IdentityFingerprint, MemPalaceLite, SmritiAILite # noqa: E402
|
| 26 |
+
from smriti.backends import ( # noqa: E402
|
| 27 |
+
JsonBackend,
|
| 28 |
+
MemoryBackend,
|
| 29 |
+
MemoryCipher,
|
| 30 |
+
PostgresBackend,
|
| 31 |
+
RedisBackend,
|
| 32 |
+
SqliteBackend,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
LOGGER = logging.getLogger("smriti.hf_handler")
|
| 36 |
+
if not LOGGER.handlers:
|
| 37 |
+
logging.basicConfig(level=os.getenv("SMRITI_LOG_LEVEL", "INFO"))
|
| 38 |
+
|
| 39 |
+
DEFAULT_CONFIG = {
|
| 40 |
+
"project": "Smriti AI",
|
| 41 |
+
"base_model": "REPLACE_WITH_BASE_MODEL_ID",
|
| 42 |
+
"retrieval_mode": "semantic_graph_identity",
|
| 43 |
+
"memory_backend": "json",
|
| 44 |
+
"public_demo": False,
|
| 45 |
+
"max_memory_entries": 1000,
|
| 46 |
+
"enable_identity": True,
|
| 47 |
+
"enable_graph": True,
|
| 48 |
+
"enable_encryption": True,
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class EndpointHandler:
|
| 53 |
+
"""Hugging Face custom inference endpoint handler."""
|
| 54 |
+
|
| 55 |
+
def __init__(self, path: str = ""):
|
| 56 |
+
self.root = _resolve_root(path)
|
| 57 |
+
self.config = _load_config(self.root / "config.json")
|
| 58 |
+
self.lock = RLock()
|
| 59 |
+
self.memories: Dict[str, MemPalaceLite] = {}
|
| 60 |
+
self.identities: Dict[str, IdentityFingerprint] = {}
|
| 61 |
+
self.backend_warning: Optional[str] = None
|
| 62 |
+
|
| 63 |
+
self.base_model_id = _clean_model_id(
|
| 64 |
+
os.getenv("BASE_MODEL_ID") or self.config.get("base_model", "")
|
| 65 |
+
)
|
| 66 |
+
self.endpoint_url = os.getenv("HF_ENDPOINT_URL", "").strip()
|
| 67 |
+
self.hf_token = os.getenv("HF_TOKEN", "").strip()
|
| 68 |
+
self.default_retrieval_mode = os.getenv(
|
| 69 |
+
"SMRITI_RETRIEVAL_MODE",
|
| 70 |
+
str(self.config.get("retrieval_mode", "semantic_graph_identity")),
|
| 71 |
+
)
|
| 72 |
+
self.max_memory_entries = _int_env(
|
| 73 |
+
"SMRITI_MAX_MEMORY_ENTRIES",
|
| 74 |
+
int(self.config.get("max_memory_entries", 1000)),
|
| 75 |
+
)
|
| 76 |
+
self.public_demo = _bool_env("SMRITI_PUBLIC_DEMO", bool(self.config.get("public_demo", False)))
|
| 77 |
+
self.enable_graph_default = bool(self.config.get("enable_graph", True))
|
| 78 |
+
self.enable_identity_default = bool(self.config.get("enable_identity", True))
|
| 79 |
+
self.enable_encryption = bool(self.config.get("enable_encryption", True))
|
| 80 |
+
|
| 81 |
+
self.backend, self.backend_name = self._init_backend()
|
| 82 |
+
self.model = None
|
| 83 |
+
self.tokenizer = None
|
| 84 |
+
self.device = "cpu"
|
| 85 |
+
if self.endpoint_url:
|
| 86 |
+
LOGGER.info(
|
| 87 |
+
"Smriti AI handler using remote model endpoint; backend=%s retrieval=%s",
|
| 88 |
+
self.backend_name,
|
| 89 |
+
self.default_retrieval_mode,
|
| 90 |
+
)
|
| 91 |
+
elif self.base_model_id:
|
| 92 |
+
self._load_local_model(self.base_model_id)
|
| 93 |
+
else:
|
| 94 |
+
LOGGER.warning(
|
| 95 |
+
"No BASE_MODEL_ID or HF_ENDPOINT_URL configured; handler will run memory-only."
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
LOGGER.info(
|
| 99 |
+
"Smriti AI handler ready: base_model=%s remote_endpoint=%s backend=%s retrieval=%s encryption=%s public_demo=%s",
|
| 100 |
+
self.base_model_id or "memory-only",
|
| 101 |
+
bool(self.endpoint_url),
|
| 102 |
+
self.backend_name,
|
| 103 |
+
self.default_retrieval_mode,
|
| 104 |
+
self.enable_encryption and bool(os.getenv("SMRITI_ENCRYPTION_KEY") or os.getenv("SMRITI_MEMORY_KEY")),
|
| 105 |
+
self.public_demo,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 109 |
+
start = time.perf_counter()
|
| 110 |
+
try:
|
| 111 |
+
inputs, parameters = _normalize_request(data)
|
| 112 |
+
operation = str(inputs.get("operation", "chat")).lower()
|
| 113 |
+
|
| 114 |
+
if operation == "health":
|
| 115 |
+
return self._health(start)
|
| 116 |
+
if operation == "delete_memory":
|
| 117 |
+
return self._delete_memory(inputs, start)
|
| 118 |
+
if operation != "chat":
|
| 119 |
+
return _error(f"Unsupported operation: {operation}", start)
|
| 120 |
+
|
| 121 |
+
return self._chat(inputs, parameters, start)
|
| 122 |
+
except Exception as exc: # Defensive boundary for endpoint runtimes.
|
| 123 |
+
LOGGER.exception("Unhandled Smriti AI handler error")
|
| 124 |
+
return _error(f"handler_error:{exc.__class__.__name__}: {exc}", start)
|
| 125 |
+
|
| 126 |
+
# ------------------------------------------------------------------
|
| 127 |
+
# Operation handlers
|
| 128 |
+
# ------------------------------------------------------------------
|
| 129 |
+
|
| 130 |
+
def _chat(
|
| 131 |
+
self,
|
| 132 |
+
inputs: Dict[str, Any],
|
| 133 |
+
parameters: Dict[str, Any],
|
| 134 |
+
start: float,
|
| 135 |
+
) -> Dict[str, Any]:
|
| 136 |
+
user_id = str(inputs.get("user_id") or "").strip()
|
| 137 |
+
message = str(inputs.get("message") or "").strip()
|
| 138 |
+
topic_id = str(inputs.get("topic_id") or "general").strip() or "general"
|
| 139 |
+
if not user_id:
|
| 140 |
+
return _error("user_id is required", start)
|
| 141 |
+
if not message:
|
| 142 |
+
return _error("message is required for chat operation", start)
|
| 143 |
+
|
| 144 |
+
retrieval_mode = str(inputs.get("retrieval_mode") or self.default_retrieval_mode)
|
| 145 |
+
base_retrieval = _base_retrieval_mode(retrieval_mode)
|
| 146 |
+
include_graph = self.enable_graph_default and "graph" in retrieval_mode
|
| 147 |
+
identity_enabled = self.enable_identity_default and "identity" in retrieval_mode
|
| 148 |
+
|
| 149 |
+
with self.lock:
|
| 150 |
+
memory = self._get_memory(user_id, topic_id, base_retrieval)
|
| 151 |
+
context, retrieved_memories, graph_facts, retrieval_warning = self._retrieve_context(
|
| 152 |
+
memory,
|
| 153 |
+
user_id,
|
| 154 |
+
topic_id,
|
| 155 |
+
message,
|
| 156 |
+
include_graph,
|
| 157 |
+
)
|
| 158 |
+
identity = self._get_identity(user_id, identity_enabled)
|
| 159 |
+
agent = SmritiAILite(
|
| 160 |
+
model=self.model,
|
| 161 |
+
tokenizer=self.tokenizer,
|
| 162 |
+
retrieval_mode=base_retrieval,
|
| 163 |
+
session_id=user_id,
|
| 164 |
+
topic_id=topic_id,
|
| 165 |
+
memory=memory,
|
| 166 |
+
identity=identity,
|
| 167 |
+
auto_device=False,
|
| 168 |
+
)
|
| 169 |
+
agent.build_prompt = lambda user_input: _build_prompt(
|
| 170 |
+
agent,
|
| 171 |
+
memory,
|
| 172 |
+
user_id,
|
| 173 |
+
topic_id,
|
| 174 |
+
user_input,
|
| 175 |
+
include_graph,
|
| 176 |
+
identity_enabled,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
generation_calls = 0
|
| 180 |
+
|
| 181 |
+
def generate(prompt: str, max_tokens: int = 256) -> str:
|
| 182 |
+
nonlocal generation_calls
|
| 183 |
+
generation_calls += 1
|
| 184 |
+
return self._generate_text(prompt, parameters, max_tokens=max_tokens)
|
| 185 |
+
|
| 186 |
+
agent._generate = generate # type: ignore[method-assign]
|
| 187 |
+
try:
|
| 188 |
+
response = agent.chat(message)
|
| 189 |
+
except Exception as exc:
|
| 190 |
+
LOGGER.exception("Model generation failed")
|
| 191 |
+
return _error(f"model_generation_failed:{exc.__class__.__name__}: {exc}", start)
|
| 192 |
+
|
| 193 |
+
response = _stabilize_recall_answer(message, response, retrieved_memories, graph_facts)
|
| 194 |
+
_replace_last_assistant_history(memory, response)
|
| 195 |
+
|
| 196 |
+
identity_check = agent.identity.evaluate_output(response) if identity_enabled else None
|
| 197 |
+
save_warning = self._save_memory(user_id, memory)
|
| 198 |
+
|
| 199 |
+
warnings = [item for item in [self.backend_warning, retrieval_warning, save_warning] if item]
|
| 200 |
+
return {
|
| 201 |
+
"response": response,
|
| 202 |
+
"retrieved_memories": retrieved_memories,
|
| 203 |
+
"graph_facts": graph_facts,
|
| 204 |
+
"identity": {
|
| 205 |
+
"enabled": identity_enabled,
|
| 206 |
+
"drift_score": float(identity_check.distance) if identity_check else 0.0,
|
| 207 |
+
"refinement_triggered": generation_calls > 1,
|
| 208 |
+
},
|
| 209 |
+
"latency_ms": round((time.perf_counter() - start) * 1000, 3),
|
| 210 |
+
"backend": self.backend_name,
|
| 211 |
+
"retrieval_mode": retrieval_mode,
|
| 212 |
+
"warnings": warnings,
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
def _delete_memory(self, inputs: Dict[str, Any], start: float) -> Dict[str, Any]:
|
| 216 |
+
user_id = str(inputs.get("user_id") or "").strip()
|
| 217 |
+
if not user_id:
|
| 218 |
+
return _error("user_id is required for delete_memory operation", start)
|
| 219 |
+
with self.lock:
|
| 220 |
+
existed_cache = self.memories.pop(user_id, None) is not None
|
| 221 |
+
self.identities.pop(user_id, None)
|
| 222 |
+
try:
|
| 223 |
+
deleted_backend = self.backend.delete_user(user_id)
|
| 224 |
+
except Exception as exc:
|
| 225 |
+
LOGGER.exception("Memory backend delete failed")
|
| 226 |
+
return _error(f"backend_delete_failed:{exc.__class__.__name__}: {exc}", start)
|
| 227 |
+
return {
|
| 228 |
+
"deleted": bool(existed_cache or deleted_backend),
|
| 229 |
+
"user_id": user_id,
|
| 230 |
+
"latency_ms": round((time.perf_counter() - start) * 1000, 3),
|
| 231 |
+
"backend": self.backend_name,
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
def _health(self, start: float) -> Dict[str, Any]:
|
| 235 |
+
return {
|
| 236 |
+
"status": "ok",
|
| 237 |
+
"project": "Smriti AI",
|
| 238 |
+
"base_model": self.base_model_id or ("remote-endpoint" if self.endpoint_url else "memory-only"),
|
| 239 |
+
"backend": self.backend_name,
|
| 240 |
+
"retrieval_mode": self.default_retrieval_mode,
|
| 241 |
+
"latency_ms": round((time.perf_counter() - start) * 1000, 3),
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
# ------------------------------------------------------------------
|
| 245 |
+
# Runtime setup
|
| 246 |
+
# ------------------------------------------------------------------
|
| 247 |
+
|
| 248 |
+
def _init_backend(self) -> Tuple[MemoryBackend, str]:
|
| 249 |
+
encryption_key = os.getenv("SMRITI_ENCRYPTION_KEY") or os.getenv("SMRITI_MEMORY_KEY")
|
| 250 |
+
if encryption_key:
|
| 251 |
+
os.environ["SMRITI_MEMORY_KEY"] = encryption_key
|
| 252 |
+
cipher = MemoryCipher(encryption_key if self.enable_encryption else None)
|
| 253 |
+
|
| 254 |
+
redis_url = os.getenv("REDIS_URL") or os.getenv("SMRITI_REDIS_URL")
|
| 255 |
+
postgres_dsn = os.getenv("POSTGRES_DSN") or os.getenv("SMRITI_POSTGRES_DSN")
|
| 256 |
+
selected = (os.getenv("SMRITI_MEMORY_BACKEND") or self.config.get("memory_backend") or "json").lower()
|
| 257 |
+
memory_path = os.getenv("SMRITI_MEMORY_PATH", "/tmp/smriti_hf_memory")
|
| 258 |
+
|
| 259 |
+
if redis_url:
|
| 260 |
+
return RedisBackend(url=redis_url, cipher=cipher), "redis"
|
| 261 |
+
if postgres_dsn:
|
| 262 |
+
return PostgresBackend(dsn=postgres_dsn, cipher=cipher), "postgres"
|
| 263 |
+
if selected == "redis":
|
| 264 |
+
return RedisBackend(url=redis_url or "redis://localhost:6379/0", cipher=cipher), "redis"
|
| 265 |
+
if selected in {"postgres", "postgresql"}:
|
| 266 |
+
return PostgresBackend(dsn=postgres_dsn or "", cipher=cipher), "postgres"
|
| 267 |
+
if selected == "sqlite":
|
| 268 |
+
return SqliteBackend(path=memory_path, cipher=cipher), "sqlite"
|
| 269 |
+
return JsonBackend(root=_json_root(memory_path), cipher=cipher), "json"
|
| 270 |
+
|
| 271 |
+
def _load_local_model(self, model_id: str) -> None:
|
| 272 |
+
try:
|
| 273 |
+
import torch
|
| 274 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 275 |
+
except Exception as exc:
|
| 276 |
+
raise RuntimeError("Install torch and transformers to load a local base model.") from exc
|
| 277 |
+
|
| 278 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 279 |
+
dtype = torch.float32
|
| 280 |
+
if self.device == "cuda":
|
| 281 |
+
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
| 282 |
+
|
| 283 |
+
kwargs = {"token": self.hf_token} if self.hf_token else {}
|
| 284 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
|
| 285 |
+
if getattr(self.tokenizer, "pad_token_id", None) is None:
|
| 286 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 287 |
+
try:
|
| 288 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, **kwargs)
|
| 289 |
+
except TypeError:
|
| 290 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **kwargs)
|
| 291 |
+
self.model.to(self.device)
|
| 292 |
+
self.model.eval()
|
| 293 |
+
LOGGER.info("Loaded local base model %s on %s", model_id, self.device)
|
| 294 |
+
|
| 295 |
+
# ------------------------------------------------------------------
|
| 296 |
+
# Memory and generation helpers
|
| 297 |
+
# ------------------------------------------------------------------
|
| 298 |
+
|
| 299 |
+
def _get_memory(self, user_id: str, topic_id: str, retrieval_mode: str) -> MemPalaceLite:
|
| 300 |
+
self.backend_warning = None
|
| 301 |
+
if user_id not in self.memories:
|
| 302 |
+
state = None
|
| 303 |
+
try:
|
| 304 |
+
state = self.backend.load(user_id)
|
| 305 |
+
except Exception as exc:
|
| 306 |
+
LOGGER.exception("Memory backend load failed; starting empty memory")
|
| 307 |
+
self.backend_warning = f"backend_load_failed:{exc.__class__.__name__}"
|
| 308 |
+
if state:
|
| 309 |
+
memory = MemPalaceLite.from_dict(
|
| 310 |
+
state,
|
| 311 |
+
retrieval_mode=retrieval_mode,
|
| 312 |
+
session_id=user_id,
|
| 313 |
+
topic_id=topic_id,
|
| 314 |
+
max_facts=self.max_memory_entries,
|
| 315 |
+
max_entries_per_topic=self.max_memory_entries,
|
| 316 |
+
)
|
| 317 |
+
else:
|
| 318 |
+
memory = MemPalaceLite(
|
| 319 |
+
retrieval_mode=retrieval_mode,
|
| 320 |
+
session_id=user_id,
|
| 321 |
+
topic_id=topic_id,
|
| 322 |
+
max_facts=self.max_memory_entries,
|
| 323 |
+
max_entries_per_topic=self.max_memory_entries,
|
| 324 |
+
)
|
| 325 |
+
self.memories[user_id] = memory
|
| 326 |
+
memory = self.memories[user_id]
|
| 327 |
+
if memory.retrieval_mode != retrieval_mode:
|
| 328 |
+
memory = MemPalaceLite.from_dict(
|
| 329 |
+
memory.to_dict(),
|
| 330 |
+
retrieval_mode=retrieval_mode,
|
| 331 |
+
session_id=user_id,
|
| 332 |
+
topic_id=topic_id,
|
| 333 |
+
max_facts=self.max_memory_entries,
|
| 334 |
+
max_entries_per_topic=self.max_memory_entries,
|
| 335 |
+
)
|
| 336 |
+
self.memories[user_id] = memory
|
| 337 |
+
memory.session_id = user_id
|
| 338 |
+
memory.topic_id = topic_id
|
| 339 |
+
return memory
|
| 340 |
+
|
| 341 |
+
def _get_identity(self, user_id: str, enabled: bool) -> IdentityFingerprint:
|
| 342 |
+
if user_id not in self.identities:
|
| 343 |
+
threshold = 0.35 if enabled else 2.0
|
| 344 |
+
self.identities[user_id] = IdentityFingerprint(
|
| 345 |
+
role="helpful AI assistant with persistent memory",
|
| 346 |
+
threshold=threshold,
|
| 347 |
+
)
|
| 348 |
+
identity = self.identities[user_id]
|
| 349 |
+
if not enabled:
|
| 350 |
+
identity.threshold = 2.0
|
| 351 |
+
return identity
|
| 352 |
+
|
| 353 |
+
def _retrieve_context(
|
| 354 |
+
self,
|
| 355 |
+
memory: MemPalaceLite,
|
| 356 |
+
user_id: str,
|
| 357 |
+
topic_id: str,
|
| 358 |
+
message: str,
|
| 359 |
+
include_graph: bool,
|
| 360 |
+
) -> Tuple[str, List[str], List[str], Optional[str]]:
|
| 361 |
+
try:
|
| 362 |
+
context = memory.get_context(
|
| 363 |
+
query=message,
|
| 364 |
+
session_id=user_id,
|
| 365 |
+
topic_id=topic_id,
|
| 366 |
+
include_graph=include_graph,
|
| 367 |
+
)
|
| 368 |
+
retrieved_memories = memory.retrieve_facts(
|
| 369 |
+
message,
|
| 370 |
+
k=5,
|
| 371 |
+
session_id=user_id,
|
| 372 |
+
topic_id=topic_id,
|
| 373 |
+
)
|
| 374 |
+
graph_facts = _section_bullets(context, "[RELATED GRAPH FACTS]") if include_graph else []
|
| 375 |
+
return context, retrieved_memories, graph_facts, None
|
| 376 |
+
except Exception as exc:
|
| 377 |
+
LOGGER.exception("Memory retrieval failed")
|
| 378 |
+
return "", [], [], f"retrieval_failed:{exc.__class__.__name__}"
|
| 379 |
+
|
| 380 |
+
def _save_memory(self, user_id: str, memory: MemPalaceLite) -> Optional[str]:
|
| 381 |
+
try:
|
| 382 |
+
self.backend.save(user_id, memory.to_dict())
|
| 383 |
+
return None
|
| 384 |
+
except Exception as exc:
|
| 385 |
+
LOGGER.exception("Memory backend save failed")
|
| 386 |
+
return f"backend_save_failed:{exc.__class__.__name__}"
|
| 387 |
+
|
| 388 |
+
def _generate_text(self, prompt: str, parameters: Dict[str, Any], max_tokens: int = 256) -> str:
|
| 389 |
+
max_new_tokens = int(parameters.get("max_new_tokens", max_tokens) or max_tokens)
|
| 390 |
+
temperature = float(parameters.get("temperature", 0.7))
|
| 391 |
+
top_p = float(parameters.get("top_p", 0.9))
|
| 392 |
+
if self.endpoint_url:
|
| 393 |
+
return self._generate_remote(prompt, max_new_tokens, temperature, top_p)
|
| 394 |
+
if self.model is not None and self.tokenizer is not None:
|
| 395 |
+
return self._generate_local(prompt, max_new_tokens, temperature, top_p)
|
| 396 |
+
return _memory_only_answer(prompt)
|
| 397 |
+
|
| 398 |
+
def _generate_local(
|
| 399 |
+
self,
|
| 400 |
+
prompt: str,
|
| 401 |
+
max_new_tokens: int,
|
| 402 |
+
temperature: float,
|
| 403 |
+
top_p: float,
|
| 404 |
+
) -> str:
|
| 405 |
+
import torch
|
| 406 |
+
|
| 407 |
+
messages = [{"role": "user", "content": prompt}]
|
| 408 |
+
try:
|
| 409 |
+
formatted = self.tokenizer.apply_chat_template(
|
| 410 |
+
messages,
|
| 411 |
+
tokenize=False,
|
| 412 |
+
add_generation_prompt=True,
|
| 413 |
+
)
|
| 414 |
+
except Exception:
|
| 415 |
+
formatted = prompt
|
| 416 |
+
inputs = self.tokenizer(
|
| 417 |
+
formatted,
|
| 418 |
+
return_tensors="pt",
|
| 419 |
+
truncation=True,
|
| 420 |
+
max_length=2048,
|
| 421 |
+
)
|
| 422 |
+
inputs = {key: value.to(self.device) for key, value in inputs.items()}
|
| 423 |
+
generate_kwargs = {
|
| 424 |
+
"max_new_tokens": max_new_tokens,
|
| 425 |
+
"do_sample": temperature > 0,
|
| 426 |
+
"pad_token_id": getattr(self.tokenizer, "eos_token_id", None),
|
| 427 |
+
}
|
| 428 |
+
if temperature > 0:
|
| 429 |
+
generate_kwargs["temperature"] = temperature
|
| 430 |
+
generate_kwargs["top_p"] = top_p
|
| 431 |
+
with torch.inference_mode():
|
| 432 |
+
output = self.model.generate(**inputs, **generate_kwargs)
|
| 433 |
+
return self.tokenizer.decode(
|
| 434 |
+
output[0, inputs["input_ids"].shape[1] :].detach().cpu(),
|
| 435 |
+
skip_special_tokens=True,
|
| 436 |
+
).strip()
|
| 437 |
+
|
| 438 |
+
def _generate_remote(
|
| 439 |
+
self,
|
| 440 |
+
prompt: str,
|
| 441 |
+
max_new_tokens: int,
|
| 442 |
+
temperature: float,
|
| 443 |
+
top_p: float,
|
| 444 |
+
) -> str:
|
| 445 |
+
payload = {
|
| 446 |
+
"inputs": prompt,
|
| 447 |
+
"parameters": {
|
| 448 |
+
"max_new_tokens": max_new_tokens,
|
| 449 |
+
"temperature": temperature,
|
| 450 |
+
"top_p": top_p,
|
| 451 |
+
},
|
| 452 |
+
}
|
| 453 |
+
headers = {"Content-Type": "application/json"}
|
| 454 |
+
if self.hf_token:
|
| 455 |
+
headers["Authorization"] = f"Bearer {self.hf_token}"
|
| 456 |
+
request = urllib.request.Request(
|
| 457 |
+
self.endpoint_url,
|
| 458 |
+
data=json.dumps(payload).encode("utf-8"),
|
| 459 |
+
headers=headers,
|
| 460 |
+
method="POST",
|
| 461 |
+
)
|
| 462 |
+
try:
|
| 463 |
+
with urllib.request.urlopen(request, timeout=120) as response: # noqa: S310
|
| 464 |
+
raw = response.read().decode("utf-8")
|
| 465 |
+
except urllib.error.HTTPError as exc:
|
| 466 |
+
body = exc.read().decode("utf-8", errors="replace")
|
| 467 |
+
raise RuntimeError(f"remote endpoint HTTP {exc.code}: {body[:300]}") from exc
|
| 468 |
+
parsed = json.loads(raw)
|
| 469 |
+
return _extract_generated_text(parsed)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
# ----------------------------------------------------------------------
|
| 473 |
+
# Request, context, and formatting helpers
|
| 474 |
+
# ----------------------------------------------------------------------
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def _resolve_root(path: str) -> Path:
|
| 478 |
+
if path:
|
| 479 |
+
root = Path(path).resolve()
|
| 480 |
+
return root.parent if root.is_file() else root
|
| 481 |
+
return Path(__file__).resolve().parent
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def _load_config(path: Path) -> Dict[str, Any]:
|
| 485 |
+
if not path.exists():
|
| 486 |
+
return dict(DEFAULT_CONFIG)
|
| 487 |
+
data = json.loads(path.read_text(encoding="utf-8"))
|
| 488 |
+
config = dict(DEFAULT_CONFIG)
|
| 489 |
+
config.update(data)
|
| 490 |
+
return config
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def _normalize_request(data: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 494 |
+
if not isinstance(data, dict):
|
| 495 |
+
raise ValueError("Request body must be a JSON object.")
|
| 496 |
+
if "inputs" in data:
|
| 497 |
+
inputs = data.get("inputs") or {}
|
| 498 |
+
if isinstance(inputs, str):
|
| 499 |
+
inputs = {"message": inputs}
|
| 500 |
+
parameters = data.get("parameters") or {}
|
| 501 |
+
else:
|
| 502 |
+
inputs = data
|
| 503 |
+
parameters = data.get("parameters") or {}
|
| 504 |
+
if not isinstance(inputs, dict) or not isinstance(parameters, dict):
|
| 505 |
+
raise ValueError("inputs and parameters must be JSON objects.")
|
| 506 |
+
return inputs, parameters
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def _base_retrieval_mode(mode: str) -> str:
|
| 510 |
+
return "tfidf" if str(mode).lower().startswith("tfidf") else "semantic"
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def _build_prompt(
|
| 514 |
+
agent: SmritiAILite,
|
| 515 |
+
memory: MemPalaceLite,
|
| 516 |
+
user_id: str,
|
| 517 |
+
topic_id: str,
|
| 518 |
+
user_input: str,
|
| 519 |
+
include_graph: bool,
|
| 520 |
+
identity_enabled: bool,
|
| 521 |
+
) -> str:
|
| 522 |
+
identity = agent.identity.get_identity_prompt() if identity_enabled else ""
|
| 523 |
+
context = memory.get_context(
|
| 524 |
+
query=user_input,
|
| 525 |
+
session_id=user_id,
|
| 526 |
+
topic_id=topic_id,
|
| 527 |
+
include_graph=include_graph,
|
| 528 |
+
)
|
| 529 |
+
parts = [part for part in [identity.strip(), context.strip(), user_input.strip()] if part]
|
| 530 |
+
return "\n\n".join(parts)
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def _section_bullets(context: str, heading: str) -> List[str]:
|
| 534 |
+
if heading not in context:
|
| 535 |
+
return []
|
| 536 |
+
after = context.split(heading, 1)[1]
|
| 537 |
+
chunks = re.split(r"\n\[[A-Z ]+\]", after, maxsplit=1)
|
| 538 |
+
section = chunks[0]
|
| 539 |
+
bullets = []
|
| 540 |
+
for line in section.splitlines():
|
| 541 |
+
cleaned = line.strip()
|
| 542 |
+
if cleaned.startswith("*"):
|
| 543 |
+
bullets.append(cleaned.lstrip("* ").strip())
|
| 544 |
+
return bullets
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def _memory_only_answer(prompt: str) -> str:
|
| 548 |
+
facts = _section_bullets(prompt, "[REMEMBERED FACTS]")
|
| 549 |
+
graph = _section_bullets(prompt, "[RELATED GRAPH FACTS]")
|
| 550 |
+
combined = facts + [item for item in graph if item not in facts]
|
| 551 |
+
if combined:
|
| 552 |
+
return "I remember: " + "; ".join(combined[:5])
|
| 553 |
+
return "Memory updated. No prior relevant context was found."
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def _is_recall_query(message: str) -> bool:
|
| 557 |
+
lowered = message.lower()
|
| 558 |
+
return any(
|
| 559 |
+
phrase in lowered
|
| 560 |
+
for phrase in [
|
| 561 |
+
"remember",
|
| 562 |
+
"what do you know about me",
|
| 563 |
+
"who am i",
|
| 564 |
+
"where do i work",
|
| 565 |
+
"what is my name",
|
| 566 |
+
"what do i do",
|
| 567 |
+
]
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def _stabilize_recall_answer(
|
| 572 |
+
message: str,
|
| 573 |
+
response: str,
|
| 574 |
+
retrieved_memories: List[str],
|
| 575 |
+
graph_facts: List[str],
|
| 576 |
+
) -> str:
|
| 577 |
+
if not _is_recall_query(message):
|
| 578 |
+
return response
|
| 579 |
+
combined = retrieved_memories + [item for item in graph_facts if item not in retrieved_memories]
|
| 580 |
+
if not combined:
|
| 581 |
+
return response
|
| 582 |
+
if _mentions_memory_terms(response, combined):
|
| 583 |
+
return response
|
| 584 |
+
return "I remember: " + "; ".join(combined[:5])
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def _mentions_memory_terms(response: str, memories: List[str]) -> bool:
|
| 588 |
+
response_terms = set(re.findall(r"[a-z0-9']{4,}", response.lower()))
|
| 589 |
+
memory_terms = set()
|
| 590 |
+
for memory in memories:
|
| 591 |
+
memory_terms.update(re.findall(r"[a-z0-9']{4,}", memory.lower()))
|
| 592 |
+
return bool(response_terms & memory_terms)
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def _replace_last_assistant_history(memory: MemPalaceLite, response: str) -> None:
|
| 596 |
+
if memory.history and memory.history[-1].category == "assistant_output":
|
| 597 |
+
memory.history[-1].content = "Assistant: " + response[:200]
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def _extract_generated_text(parsed: Any) -> str:
|
| 601 |
+
if isinstance(parsed, list) and parsed:
|
| 602 |
+
return _extract_generated_text(parsed[0])
|
| 603 |
+
if isinstance(parsed, dict):
|
| 604 |
+
for key in ["generated_text", "response", "text", "output"]:
|
| 605 |
+
value = parsed.get(key)
|
| 606 |
+
if isinstance(value, str):
|
| 607 |
+
return value.strip()
|
| 608 |
+
if "outputs" in parsed:
|
| 609 |
+
return _extract_generated_text(parsed["outputs"])
|
| 610 |
+
if isinstance(parsed, str):
|
| 611 |
+
return parsed.strip()
|
| 612 |
+
raise RuntimeError("Remote endpoint did not return generated text.")
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def _json_root(memory_path: str) -> Path:
|
| 616 |
+
path = Path(memory_path)
|
| 617 |
+
if path.suffix.lower() in {".json", ".jsonl"}:
|
| 618 |
+
return path.with_suffix("")
|
| 619 |
+
return path
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def _clean_model_id(value: str) -> str:
|
| 623 |
+
value = (value or "").strip()
|
| 624 |
+
if not value or value == "REPLACE_WITH_BASE_MODEL_ID":
|
| 625 |
+
return ""
|
| 626 |
+
return value
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def _bool_env(name: str, default: bool) -> bool:
|
| 630 |
+
raw = os.getenv(name)
|
| 631 |
+
if raw is None:
|
| 632 |
+
return default
|
| 633 |
+
return raw.strip().lower() in {"1", "true", "yes", "on"}
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def _int_env(name: str, default: int) -> int:
|
| 637 |
+
try:
|
| 638 |
+
return int(os.getenv(name, str(default)))
|
| 639 |
+
except ValueError:
|
| 640 |
+
return default
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def _error(message: str, start: float) -> Dict[str, Any]:
|
| 644 |
+
return {
|
| 645 |
+
"error": message,
|
| 646 |
+
"latency_ms": round((time.perf_counter() - start) * 1000, 3),
|
| 647 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Smriti AI is not yet assumed to be published on PyPI for this deployment artifact.
|
| 2 |
+
# Until it is published, install the package directly from the GitHub repository.
|
| 3 |
+
git+https://github.com/Luciferai04/smriti-ai.git
|
| 4 |
+
|
| 5 |
+
# After PyPI publication, replace the GitHub line above with:
|
| 6 |
+
# smriti-ai>=0.3.1
|
| 7 |
+
|
| 8 |
+
transformers
|
| 9 |
+
accelerate
|
| 10 |
+
torch
|
| 11 |
+
sentence-transformers
|
| 12 |
+
faiss-cpu
|
| 13 |
+
networkx
|
| 14 |
+
cryptography
|
| 15 |
+
pydantic
|
| 16 |
+
redis
|
| 17 |
+
psycopg2-binary
|
| 18 |
+
huggingface_hub
|
| 19 |
+
requests
|
smriti_endpoint_config.yaml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Smriti AI Hugging Face Inference Endpoint configuration template.
|
| 2 |
+
# Values here are documentation defaults. Set real values as endpoint environment
|
| 3 |
+
# variables or managed secrets, not as committed plaintext.
|
| 4 |
+
|
| 5 |
+
BASE_MODEL_ID: ""
|
| 6 |
+
HF_ENDPOINT_URL: ""
|
| 7 |
+
HF_TOKEN: ""
|
| 8 |
+
SMRITI_MEMORY_BACKEND: "json"
|
| 9 |
+
SMRITI_MEMORY_PATH: "/data/smriti_memory"
|
| 10 |
+
REDIS_URL: ""
|
| 11 |
+
POSTGRES_DSN: ""
|
| 12 |
+
SMRITI_ENCRYPTION_KEY: ""
|
| 13 |
+
SMRITI_RETRIEVAL_MODE: "semantic_graph_identity"
|
| 14 |
+
SMRITI_PUBLIC_DEMO: "false"
|
| 15 |
+
SMRITI_MAX_MEMORY_ENTRIES: "1000"
|
| 16 |
+
|
| 17 |
+
warnings:
|
| 18 |
+
- Do not commit HF_TOKEN.
|
| 19 |
+
- Do not commit SMRITI_ENCRYPTION_KEY.
|
| 20 |
+
- Production memory should use Redis/Postgres or another external durable storage service.
|
| 21 |
+
- The Hugging Face model repository should not contain user memory files.
|
| 22 |
+
- Public demo endpoints should not receive real PII.
|
| 23 |
+
|
| 24 |
+
variables:
|
| 25 |
+
BASE_MODEL_ID: Hugging Face model ID to load locally inside the endpoint.
|
| 26 |
+
HF_ENDPOINT_URL: Optional remote model endpoint URL. If set, Smriti calls it instead of loading BASE_MODEL_ID locally.
|
| 27 |
+
HF_TOKEN: Hugging Face token for gated/private base models or protected remote endpoints.
|
| 28 |
+
SMRITI_MEMORY_BACKEND: json | sqlite | redis | postgres.
|
| 29 |
+
SMRITI_MEMORY_PATH: Path for JSON user-memory directory or SQLite database file.
|
| 30 |
+
REDIS_URL: External Redis URL. Takes precedence when present.
|
| 31 |
+
POSTGRES_DSN: External Postgres DSN. Takes precedence when present and REDIS_URL is empty.
|
| 32 |
+
SMRITI_ENCRYPTION_KEY: Encryption key for user memory. Maps to Smriti's SMRITI_MEMORY_KEY.
|
| 33 |
+
SMRITI_RETRIEVAL_MODE: tfidf | semantic | semantic_graph | semantic_graph_identity.
|
| 34 |
+
SMRITI_PUBLIC_DEMO: true | false. Use true only for non-PII demos.
|
| 35 |
+
SMRITI_MAX_MEMORY_ENTRIES: Maximum fact entries retained per user/topic.
|
smriti_vendor/mempalace/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Backward-compatible imports for the renamed :mod:`smriti` package."""
|
| 2 |
+
|
| 3 |
+
from smriti import * # noqa: F401,F403
|
smriti_vendor/mempalace/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (224 Bytes). View file
|
|
|
smriti_vendor/mempalace/__pycache__/agent.cpython-310.pyc
ADDED
|
Binary file (207 Bytes). View file
|
|
|
smriti_vendor/mempalace/__pycache__/api.cpython-310.pyc
ADDED
|
Binary file (201 Bytes). View file
|
|
|
smriti_vendor/mempalace/__pycache__/cli.cpython-310.pyc
ADDED
|
Binary file (201 Bytes). View file
|
|
|
smriti_vendor/mempalace/__pycache__/core.cpython-310.pyc
ADDED
|
Binary file (204 Bytes). View file
|
|
|
smriti_vendor/mempalace/__pycache__/gifp.cpython-310.pyc
ADDED
|
Binary file (204 Bytes). View file
|
|
|
smriti_vendor/mempalace/__pycache__/identity_fingerprint.cpython-310.pyc
ADDED
|
Binary file (252 Bytes). View file
|
|
|
smriti_vendor/mempalace/__pycache__/knowledge_graph.cpython-310.pyc
ADDED
|
Binary file (237 Bytes). View file
|
|
|
smriti_vendor/mempalace/__pycache__/macp.cpython-310.pyc
ADDED
|
Binary file (204 Bytes). View file
|
|
|
smriti_vendor/mempalace/__pycache__/mem_palace.cpython-310.pyc
ADDED
|
Binary file (222 Bytes). View file
|
|
|
smriti_vendor/mempalace/__pycache__/semantic_memory.cpython-310.pyc
ADDED
|
Binary file (237 Bytes). View file
|
|
|
smriti_vendor/mempalace/agent.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper for :mod:`smriti.agent`."""
|
| 2 |
+
|
| 3 |
+
from smriti.agent import * # noqa: F401,F403
|
smriti_vendor/mempalace/api.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper for :mod:`smriti.api`."""
|
| 2 |
+
|
| 3 |
+
from smriti.api import * # noqa: F401,F403
|
smriti_vendor/mempalace/cli.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper for :mod:`smriti.cli`."""
|
| 2 |
+
|
| 3 |
+
from smriti.cli import * # noqa: F401,F403
|
smriti_vendor/mempalace/core.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper for :mod:`smriti.core`."""
|
| 2 |
+
|
| 3 |
+
from smriti.core import * # noqa: F401,F403
|
smriti_vendor/mempalace/gifp.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper for :mod:`smriti.gifp`."""
|
| 2 |
+
|
| 3 |
+
from smriti.gifp import * # noqa: F401,F403
|
smriti_vendor/mempalace/identity_fingerprint.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper for :mod:`smriti.identity_fingerprint`."""
|
| 2 |
+
|
| 3 |
+
from smriti.identity_fingerprint import * # noqa: F401,F403
|
smriti_vendor/mempalace/knowledge_graph.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper for :mod:`smriti.knowledge_graph`."""
|
| 2 |
+
|
| 3 |
+
from smriti.knowledge_graph import * # noqa: F401,F403
|
smriti_vendor/mempalace/macp.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper for :mod:`smriti.macp`."""
|
| 2 |
+
|
| 3 |
+
from smriti.macp import * # noqa: F401,F403
|
smriti_vendor/mempalace/mem_palace.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper for :mod:`smriti.mem_palace`."""
|
| 2 |
+
|
| 3 |
+
from smriti.mem_palace import * # noqa: F401,F403
|
smriti_vendor/mempalace/semantic_memory.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper for :mod:`smriti.semantic_memory`."""
|
| 2 |
+
|
| 3 |
+
from smriti.semantic_memory import * # noqa: F401,F403
|
smriti_vendor/smriti/__init__.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Smriti AI — Inference-time memory framework for small language models.
|
| 3 |
+
|
| 4 |
+
Smriti AI adds semantic memory, reasoning continuity, and identity
|
| 5 |
+
governance to any HuggingFace causal LM with zero fine-tuning. The name
|
| 6 |
+
comes from smriti, a Sanskrit term associated with memory and remembrance.
|
| 7 |
+
|
| 8 |
+
Features:
|
| 9 |
+
- Semantic memory with FAISS-based retrieval
|
| 10 |
+
- Knowledge graph integration
|
| 11 |
+
- Embedding-based identity governance (GIFP v1.0)
|
| 12 |
+
- Multi-user support via API and CLI
|
| 13 |
+
|
| 14 |
+
Quick start:
|
| 15 |
+
from smriti import MemPalaceLite, SmritiAILite
|
| 16 |
+
|
| 17 |
+
memory = MemPalaceLite(retrieval_mode="semantic")
|
| 18 |
+
agent = SmritiAILite(model=model, tokenizer=tokenizer)
|
| 19 |
+
reply = agent.chat("My name is Jordan and I am a marine biologist.")
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from .agent import BaselineGemma, GodelAILite, SmritiAILite
|
| 23 |
+
from .backends import (
|
| 24 |
+
JsonBackend,
|
| 25 |
+
MemoryBackend,
|
| 26 |
+
MemoryCipher,
|
| 27 |
+
PostgresBackend,
|
| 28 |
+
RedisBackend,
|
| 29 |
+
SqliteBackend,
|
| 30 |
+
build_backend,
|
| 31 |
+
)
|
| 32 |
+
from .config import SmritiConfig, configure_environment_from_file, load_config, write_default_config
|
| 33 |
+
from .core import MemoryEntry, MemPalaceLite
|
| 34 |
+
from .gifp import GIFPLite
|
| 35 |
+
from .macp import MACPLite, ReasoningStep
|
| 36 |
+
|
| 37 |
+
# New modules for enhanced functionality
|
| 38 |
+
try:
|
| 39 |
+
from .semantic_memory import (
|
| 40 |
+
RetrievalResult,
|
| 41 |
+
SemanticMemory,
|
| 42 |
+
MemoryEntry as SemanticMemoryEntry,
|
| 43 |
+
)
|
| 44 |
+
except ImportError:
|
| 45 |
+
RetrievalResult = None
|
| 46 |
+
SemanticMemory = None
|
| 47 |
+
SemanticMemoryEntry = None
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
from .knowledge_graph import GraphTriple, KnowledgeGraphMemory
|
| 51 |
+
except ImportError:
|
| 52 |
+
GraphTriple = None
|
| 53 |
+
KnowledgeGraphMemory = None
|
| 54 |
+
|
| 55 |
+
try:
|
| 56 |
+
from .identity_fingerprint import IdentityCheck, IdentityFingerprint
|
| 57 |
+
except ImportError:
|
| 58 |
+
IdentityCheck = None
|
| 59 |
+
IdentityFingerprint = None
|
| 60 |
+
|
| 61 |
+
__version__ = "0.3.1"
|
| 62 |
+
__author__ = "Alton Lee Wei Bin (creator35lwb)"
|
| 63 |
+
|
| 64 |
+
__all__ = [
|
| 65 |
+
"MemoryEntry",
|
| 66 |
+
"MemPalaceLite",
|
| 67 |
+
"ReasoningStep",
|
| 68 |
+
"MACPLite",
|
| 69 |
+
"GIFPLite",
|
| 70 |
+
"SmritiAILite",
|
| 71 |
+
"GodelAILite",
|
| 72 |
+
"BaselineGemma",
|
| 73 |
+
"MemoryBackend",
|
| 74 |
+
"MemoryCipher",
|
| 75 |
+
"JsonBackend",
|
| 76 |
+
"SqliteBackend",
|
| 77 |
+
"RedisBackend",
|
| 78 |
+
"PostgresBackend",
|
| 79 |
+
"build_backend",
|
| 80 |
+
"SmritiConfig",
|
| 81 |
+
"load_config",
|
| 82 |
+
"configure_environment_from_file",
|
| 83 |
+
"write_default_config",
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
# Add new classes if available
|
| 87 |
+
if SemanticMemory is not None:
|
| 88 |
+
__all__.extend(["SemanticMemory", "SemanticMemoryEntry", "RetrievalResult"])
|
| 89 |
+
if KnowledgeGraphMemory is not None:
|
| 90 |
+
__all__.extend(["KnowledgeGraphMemory", "GraphTriple"])
|
| 91 |
+
if IdentityFingerprint is not None:
|
| 92 |
+
__all__.extend(["IdentityFingerprint", "IdentityCheck"])
|
| 93 |
+
__all__.extend(["api_app", "create_app", "get_memory", "set_agent_factory", "set_memory_backend", "cli_main"])
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def __getattr__(name: str):
|
| 97 |
+
"""Lazy optional API/CLI exports without double-registering Prometheus metrics."""
|
| 98 |
+
|
| 99 |
+
if name in {"api_app", "create_app", "get_memory", "set_agent_factory", "set_memory_backend"}:
|
| 100 |
+
from .api import app as api_app
|
| 101 |
+
from .api import create_app, get_memory, set_agent_factory, set_memory_backend
|
| 102 |
+
|
| 103 |
+
values = {
|
| 104 |
+
"api_app": api_app,
|
| 105 |
+
"create_app": create_app,
|
| 106 |
+
"get_memory": get_memory,
|
| 107 |
+
"set_agent_factory": set_agent_factory,
|
| 108 |
+
"set_memory_backend": set_memory_backend,
|
| 109 |
+
}
|
| 110 |
+
return values[name]
|
| 111 |
+
if name == "cli_main":
|
| 112 |
+
from .cli import main as cli_main
|
| 113 |
+
|
| 114 |
+
return cli_main
|
| 115 |
+
raise AttributeError(name)
|
smriti_vendor/smriti/__main__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run the Smriti AI CLI with `python -m smriti`."""
|
| 2 |
+
|
| 3 |
+
from .cli import main
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
if __name__ == "__main__":
|
| 7 |
+
raise SystemExit(main())
|
smriti_vendor/smriti/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (2.83 kB). View file
|
|
|
smriti_vendor/smriti/__pycache__/__main__.cpython-310.pyc
ADDED
|
Binary file (266 Bytes). View file
|
|
|
smriti_vendor/smriti/__pycache__/agent.cpython-310.pyc
ADDED
|
Binary file (8.03 kB). View file
|
|
|
smriti_vendor/smriti/__pycache__/api.cpython-310.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
smriti_vendor/smriti/__pycache__/backends.cpython-310.pyc
ADDED
|
Binary file (20.5 kB). View file
|
|
|
smriti_vendor/smriti/__pycache__/cli.cpython-310.pyc
ADDED
|
Binary file (8.95 kB). View file
|
|
|
smriti_vendor/smriti/__pycache__/config.cpython-310.pyc
ADDED
|
Binary file (5.37 kB). View file
|
|
|
smriti_vendor/smriti/__pycache__/core.cpython-310.pyc
ADDED
|
Binary file (13.9 kB). View file
|
|
|
smriti_vendor/smriti/__pycache__/gifp.cpython-310.pyc
ADDED
|
Binary file (510 Bytes). View file
|
|
|
smriti_vendor/smriti/__pycache__/identity_fingerprint.cpython-310.pyc
ADDED
|
Binary file (9.65 kB). View file
|
|
|
smriti_vendor/smriti/__pycache__/knowledge_graph.cpython-310.pyc
ADDED
|
Binary file (13.3 kB). View file
|
|
|
smriti_vendor/smriti/__pycache__/macp.cpython-310.pyc
ADDED
|
Binary file (2.13 kB). View file
|
|
|
smriti_vendor/smriti/__pycache__/mem_palace.cpython-310.pyc
ADDED
|
Binary file (289 Bytes). View file
|
|
|
smriti_vendor/smriti/__pycache__/semantic_memory.cpython-310.pyc
ADDED
|
Binary file (17.3 kB). View file
|
|
|
smriti_vendor/smriti/agent.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from contextlib import nullcontext
|
| 3 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from .core import MemPalaceLite
|
| 6 |
+
from .identity_fingerprint import IdentityFingerprint
|
| 7 |
+
from .macp import MACPLite
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
import torch
|
| 11 |
+
except Exception:
|
| 12 |
+
torch = None
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from transformers import GenerationConfig
|
| 16 |
+
except Exception:
|
| 17 |
+
GenerationConfig = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SmritiAILite:
|
| 21 |
+
"""
|
| 22 |
+
Model-agnostic SLM wrapper with semantic memory, graph memory, reasoning
|
| 23 |
+
continuity, and GIFP v1.0 identity governance.
|
| 24 |
+
|
| 25 |
+
Pass any pre-loaded HuggingFace causal LM and tokenizer.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
model: Any,
|
| 31 |
+
tokenizer: Any,
|
| 32 |
+
memory_path: Optional[str] = None,
|
| 33 |
+
retrieval_mode: str = "semantic",
|
| 34 |
+
session_id: str = "default",
|
| 35 |
+
topic_id: str = "general",
|
| 36 |
+
memory: Optional[MemPalaceLite] = None,
|
| 37 |
+
identity: Optional[IdentityFingerprint] = None,
|
| 38 |
+
auto_device: bool = True,
|
| 39 |
+
):
|
| 40 |
+
self.model = model
|
| 41 |
+
self.tokenizer = tokenizer
|
| 42 |
+
self.session_id = session_id
|
| 43 |
+
self.topic_id = topic_id
|
| 44 |
+
|
| 45 |
+
if memory is not None:
|
| 46 |
+
self.memory = memory
|
| 47 |
+
elif memory_path and os.path.exists(memory_path):
|
| 48 |
+
self.memory = MemPalaceLite.load(memory_path, retrieval_mode=retrieval_mode)
|
| 49 |
+
else:
|
| 50 |
+
self.memory = MemPalaceLite(
|
| 51 |
+
retrieval_mode=retrieval_mode,
|
| 52 |
+
session_id=session_id,
|
| 53 |
+
topic_id=topic_id,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
self.continuity = MACPLite()
|
| 57 |
+
self.identity = identity or IdentityFingerprint(
|
| 58 |
+
role="helpful AI assistant with persistent memory"
|
| 59 |
+
)
|
| 60 |
+
self.identity.set_constraints(
|
| 61 |
+
[
|
| 62 |
+
"Always be helpful and accurate",
|
| 63 |
+
"Reference previous context when relevant",
|
| 64 |
+
"Maintain logical consistency across turns",
|
| 65 |
+
"Acknowledge uncertainty when present",
|
| 66 |
+
]
|
| 67 |
+
)
|
| 68 |
+
self.device, self.autocast_dtype = configure_inference_device()
|
| 69 |
+
if auto_device:
|
| 70 |
+
self._move_model_to_best_device()
|
| 71 |
+
|
| 72 |
+
def build_prompt(self, user_input: str) -> str:
|
| 73 |
+
identity = self.identity.get_identity_prompt()
|
| 74 |
+
ctx = self.memory.get_context(
|
| 75 |
+
query=user_input,
|
| 76 |
+
session_id=self.session_id,
|
| 77 |
+
topic_id=self.topic_id,
|
| 78 |
+
)
|
| 79 |
+
if ctx:
|
| 80 |
+
return identity + "\n" + ctx + "\n\n" + user_input
|
| 81 |
+
return identity + "\n" + user_input
|
| 82 |
+
|
| 83 |
+
def _generate(self, prompt: str, max_tokens: int = 256) -> str:
|
| 84 |
+
if torch is None or GenerationConfig is None:
|
| 85 |
+
raise RuntimeError("torch and transformers are required for model generation.")
|
| 86 |
+
|
| 87 |
+
messages = [{"role": "user", "content": prompt}]
|
| 88 |
+
try:
|
| 89 |
+
formatted = self.tokenizer.apply_chat_template(
|
| 90 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 91 |
+
)
|
| 92 |
+
except Exception:
|
| 93 |
+
formatted = prompt
|
| 94 |
+
|
| 95 |
+
inputs = self.tokenizer(
|
| 96 |
+
formatted,
|
| 97 |
+
return_tensors="pt",
|
| 98 |
+
truncation=True,
|
| 99 |
+
max_length=2048,
|
| 100 |
+
)
|
| 101 |
+
model_device = _model_device(self.model) or self.device
|
| 102 |
+
inputs = {key: value.to(model_device) for key, value in inputs.items()}
|
| 103 |
+
cfg = GenerationConfig(
|
| 104 |
+
max_new_tokens=max_tokens,
|
| 105 |
+
temperature=0.7,
|
| 106 |
+
top_p=0.9,
|
| 107 |
+
do_sample=True,
|
| 108 |
+
pad_token_id=getattr(self.tokenizer, "eos_token_id", None),
|
| 109 |
+
)
|
| 110 |
+
with torch.inference_mode(), _autocast_context(model_device, self.autocast_dtype):
|
| 111 |
+
out = self.model.generate(**inputs, generation_config=cfg)
|
| 112 |
+
return self.tokenizer.decode(
|
| 113 |
+
out[0, inputs["input_ids"].shape[1] :].detach().cpu(),
|
| 114 |
+
skip_special_tokens=True,
|
| 115 |
+
).strip()
|
| 116 |
+
|
| 117 |
+
def chat(self, user_input: str, refine: bool = False) -> str:
|
| 118 |
+
self.continuity.start_chain(user_input)
|
| 119 |
+
self.identity.observe_user_input(user_input)
|
| 120 |
+
prompt = self.build_prompt(user_input)
|
| 121 |
+
response = self._generate(prompt)
|
| 122 |
+
|
| 123 |
+
context = self.memory.get_context(
|
| 124 |
+
query=user_input,
|
| 125 |
+
session_id=self.session_id,
|
| 126 |
+
topic_id=self.topic_id,
|
| 127 |
+
)
|
| 128 |
+
response, identity_check = self.identity.ensure_aligned(
|
| 129 |
+
response,
|
| 130 |
+
self._generate,
|
| 131 |
+
user_input=user_input,
|
| 132 |
+
context=context,
|
| 133 |
+
)
|
| 134 |
+
if refine and identity_check.consistency_score < 0.5:
|
| 135 |
+
response = self.identity.refinement_pass(
|
| 136 |
+
self._generate,
|
| 137 |
+
response,
|
| 138 |
+
user_input=user_input,
|
| 139 |
+
context=context,
|
| 140 |
+
)
|
| 141 |
+
identity_check = self.identity.evaluate_output(response)
|
| 142 |
+
|
| 143 |
+
self.continuity.add_step(
|
| 144 |
+
user_input,
|
| 145 |
+
response,
|
| 146 |
+
identity_check.consistency_score,
|
| 147 |
+
"continue" if identity_check.consistency_score > 0.7 else "refine",
|
| 148 |
+
)
|
| 149 |
+
for fact in self.memory.extract_facts(response, user_input=user_input):
|
| 150 |
+
self.memory.add_fact(fact, session_id=self.session_id, topic_id=self.topic_id)
|
| 151 |
+
self.memory.add_to_history("User: " + user_input, "user_input")
|
| 152 |
+
self.memory.add_to_history(
|
| 153 |
+
"Assistant: " + response[:200], "assistant_output"
|
| 154 |
+
)
|
| 155 |
+
self.identity.record_behavior(response)
|
| 156 |
+
return response
|
| 157 |
+
|
| 158 |
+
def save_memory(self, path: str):
|
| 159 |
+
self.memory.save(path)
|
| 160 |
+
|
| 161 |
+
def load_memory(self, path: str):
|
| 162 |
+
self.memory = MemPalaceLite.load(path, retrieval_mode=self.memory.retrieval_mode)
|
| 163 |
+
|
| 164 |
+
def get_memory_state(self) -> Dict:
|
| 165 |
+
return self.memory.to_dict()
|
| 166 |
+
|
| 167 |
+
def get_reasoning_chain(self) -> str:
|
| 168 |
+
return self.continuity.get_chain_summary()
|
| 169 |
+
|
| 170 |
+
def _move_model_to_best_device(self) -> None:
|
| 171 |
+
if torch is None or self.device is None or str(self.device) == "cpu":
|
| 172 |
+
return
|
| 173 |
+
try:
|
| 174 |
+
current = _model_device(self.model)
|
| 175 |
+
if current is not None and str(current).startswith("cuda"):
|
| 176 |
+
return
|
| 177 |
+
self.model.to(self.device)
|
| 178 |
+
except Exception:
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class BaselineGemma:
|
| 183 |
+
"""Plain causal LM with no memory, no identity layer, no continuity."""
|
| 184 |
+
|
| 185 |
+
def __init__(self, model: Any, tokenizer: Any, auto_device: bool = True):
|
| 186 |
+
self.model = model
|
| 187 |
+
self.tokenizer = tokenizer
|
| 188 |
+
self._history: List[str] = []
|
| 189 |
+
self.device, self.autocast_dtype = configure_inference_device()
|
| 190 |
+
if auto_device and torch is not None and str(self.device) != "cpu":
|
| 191 |
+
try:
|
| 192 |
+
self.model.to(self.device)
|
| 193 |
+
except Exception:
|
| 194 |
+
pass
|
| 195 |
+
|
| 196 |
+
def chat(self, user_input: str) -> str:
|
| 197 |
+
if torch is None or GenerationConfig is None:
|
| 198 |
+
raise RuntimeError("torch and transformers are required for model generation.")
|
| 199 |
+
|
| 200 |
+
messages = [{"role": "user", "content": user_input}]
|
| 201 |
+
try:
|
| 202 |
+
prompt = self.tokenizer.apply_chat_template(
|
| 203 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 204 |
+
)
|
| 205 |
+
except Exception:
|
| 206 |
+
prompt = user_input
|
| 207 |
+
inputs = self.tokenizer(
|
| 208 |
+
prompt,
|
| 209 |
+
return_tensors="pt",
|
| 210 |
+
truncation=True,
|
| 211 |
+
max_length=2048,
|
| 212 |
+
)
|
| 213 |
+
model_device = _model_device(self.model) or self.device
|
| 214 |
+
inputs = {key: value.to(model_device) for key, value in inputs.items()}
|
| 215 |
+
cfg = GenerationConfig(
|
| 216 |
+
max_new_tokens=getattr(self, "max_new_tokens", 256),
|
| 217 |
+
temperature=0.7,
|
| 218 |
+
top_p=0.9,
|
| 219 |
+
do_sample=True,
|
| 220 |
+
pad_token_id=getattr(self.tokenizer, "eos_token_id", None),
|
| 221 |
+
)
|
| 222 |
+
with torch.inference_mode(), _autocast_context(model_device, self.autocast_dtype):
|
| 223 |
+
out = self.model.generate(**inputs, generation_config=cfg)
|
| 224 |
+
response = self.tokenizer.decode(
|
| 225 |
+
out[0, inputs["input_ids"].shape[1] :].detach().cpu(),
|
| 226 |
+
skip_special_tokens=True,
|
| 227 |
+
).strip()
|
| 228 |
+
self._history.extend(["User: " + user_input, "Assistant: " + response])
|
| 229 |
+
return response
|
| 230 |
+
|
| 231 |
+
def reset(self):
|
| 232 |
+
self._history = []
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def configure_inference_device() -> Tuple[Any, Any]:
|
| 236 |
+
"""Return the preferred torch device and mixed-precision dtype."""
|
| 237 |
+
|
| 238 |
+
if torch is None:
|
| 239 |
+
return "cpu", None
|
| 240 |
+
if torch.cuda.is_available():
|
| 241 |
+
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
| 242 |
+
return torch.device("cuda"), dtype
|
| 243 |
+
return torch.device("cpu"), torch.float32
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def _model_device(model: Any) -> Any:
|
| 247 |
+
try:
|
| 248 |
+
return next(model.parameters()).device
|
| 249 |
+
except Exception:
|
| 250 |
+
return None
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def _autocast_context(device: Any, dtype: Any):
|
| 254 |
+
if torch is None or dtype is None:
|
| 255 |
+
return nullcontext()
|
| 256 |
+
if str(device).startswith("cuda"):
|
| 257 |
+
return torch.autocast(device_type="cuda", dtype=dtype)
|
| 258 |
+
return nullcontext()
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# Backwards compatibility for existing user code.
|
| 262 |
+
GodelAILite = SmritiAILite
|
smriti_vendor/smriti/api.py
ADDED
|
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI layer for multi-user, multi-agent Smriti AI memory access."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import uuid
|
| 10 |
+
from contextlib import contextmanager
|
| 11 |
+
from threading import RLock
|
| 12 |
+
from typing import Any, Callable, Dict, Iterator, List, Optional
|
| 13 |
+
|
| 14 |
+
from fastapi import FastAPI, HTTPException, Request, Response
|
| 15 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 16 |
+
from pydantic import BaseModel, Field
|
| 17 |
+
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, Histogram, generate_latest
|
| 18 |
+
|
| 19 |
+
from .backends import MemoryBackend, build_backend
|
| 20 |
+
from .config import configure_environment_from_file, load_config
|
| 21 |
+
from .core import MemPalaceLite
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
AgentFactory = Callable[..., Any]
|
| 25 |
+
|
| 26 |
+
USER_MEMORIES: Dict[str, MemPalaceLite] = {}
|
| 27 |
+
MEMORY_LOCK = RLock()
|
| 28 |
+
MEMORY_BACKEND: Optional[MemoryBackend] = None
|
| 29 |
+
AGENT_FACTORY: Optional[AgentFactory] = None
|
| 30 |
+
LOGGER = logging.getLogger("smriti.api")
|
| 31 |
+
LOGGER.setLevel(logging.INFO)
|
| 32 |
+
if not LOGGER.handlers:
|
| 33 |
+
_handler = logging.StreamHandler()
|
| 34 |
+
_handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
|
| 35 |
+
LOGGER.addHandler(_handler)
|
| 36 |
+
LOGGER.propagate = False
|
| 37 |
+
|
| 38 |
+
HTTP_REQUESTS = Counter(
|
| 39 |
+
"smriti_http_requests_total",
|
| 40 |
+
"Total HTTP requests handled by the Smriti AI API.",
|
| 41 |
+
("method", "path", "status"),
|
| 42 |
+
)
|
| 43 |
+
HTTP_ERRORS = Counter(
|
| 44 |
+
"smriti_http_errors_total",
|
| 45 |
+
"Total HTTP requests that completed with status >= 500.",
|
| 46 |
+
("method", "path"),
|
| 47 |
+
)
|
| 48 |
+
HTTP_LATENCY = Histogram(
|
| 49 |
+
"smriti_http_request_latency_seconds",
|
| 50 |
+
"End-to-end HTTP request latency.",
|
| 51 |
+
("method", "path"),
|
| 52 |
+
)
|
| 53 |
+
RETRIEVAL_LATENCY = Histogram(
|
| 54 |
+
"smriti_retrieval_latency_seconds",
|
| 55 |
+
"Memory retrieval latency for chat requests.",
|
| 56 |
+
("retrieval_mode",),
|
| 57 |
+
)
|
| 58 |
+
TOKEN_USAGE = Counter(
|
| 59 |
+
"smriti_tokens_total",
|
| 60 |
+
"Approximate whitespace-token count observed by the API.",
|
| 61 |
+
("user_id", "agent_id"),
|
| 62 |
+
)
|
| 63 |
+
USER_MEMORY_COUNT = Gauge(
|
| 64 |
+
"smriti_user_memories",
|
| 65 |
+
"Number of in-memory user memory stores.",
|
| 66 |
+
)
|
| 67 |
+
USER_MEMORY_BYTES = Gauge(
|
| 68 |
+
"smriti_user_memory_bytes",
|
| 69 |
+
"Approximate serialized memory size by user.",
|
| 70 |
+
("user_id",),
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class ChatRequest(BaseModel):
|
| 75 |
+
user_id: str
|
| 76 |
+
message: str
|
| 77 |
+
topic_id: str = "general"
|
| 78 |
+
agent_id: str = "executor"
|
| 79 |
+
retrieval_mode: str = "semantic"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class ChatResponse(BaseModel):
|
| 83 |
+
user_id: str
|
| 84 |
+
agent_id: str
|
| 85 |
+
topic_id: str
|
| 86 |
+
response: str
|
| 87 |
+
retrieved_context: str
|
| 88 |
+
memory: Dict[str, Any]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class MemoryLoadRequest(BaseModel):
|
| 92 |
+
user_id: str
|
| 93 |
+
memory: Optional[Dict[str, Any]] = None
|
| 94 |
+
path: Optional[str] = None
|
| 95 |
+
retrieval_mode: str = "semantic"
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class MemorySaveRequest(BaseModel):
|
| 99 |
+
user_id: str
|
| 100 |
+
path: Optional[str] = None
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class MemoryDeleteRequest(BaseModel):
|
| 104 |
+
user_id: str
|
| 105 |
+
path: Optional[str] = None
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class GraphQueryRequest(BaseModel):
|
| 109 |
+
user_id: str
|
| 110 |
+
query_entity: str
|
| 111 |
+
topic_id: Optional[str] = None
|
| 112 |
+
depth: int = Field(default=1, ge=1, le=4)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def set_agent_factory(factory: Optional[AgentFactory]) -> None:
|
| 116 |
+
"""
|
| 117 |
+
Register a callable that returns a configured model agent.
|
| 118 |
+
|
| 119 |
+
The callable receives `user_id`, `memory`, `topic_id`, and `agent_id`.
|
| 120 |
+
When no factory is configured, `/chat` runs in memory-only mode.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
global AGENT_FACTORY
|
| 124 |
+
AGENT_FACTORY = factory
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def set_memory_backend(backend: Optional[MemoryBackend]) -> None:
|
| 128 |
+
"""Override the configured persistence backend for tests or deployments."""
|
| 129 |
+
|
| 130 |
+
global MEMORY_BACKEND
|
| 131 |
+
MEMORY_BACKEND = backend
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_memory_backend() -> MemoryBackend:
|
| 135 |
+
"""Return the configured durable backend, constructing it lazily from env."""
|
| 136 |
+
|
| 137 |
+
global MEMORY_BACKEND
|
| 138 |
+
if MEMORY_BACKEND is None:
|
| 139 |
+
configure_environment_from_file()
|
| 140 |
+
MEMORY_BACKEND = build_backend()
|
| 141 |
+
return MEMORY_BACKEND
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_memory(user_id: str, retrieval_mode: str = "semantic") -> MemPalaceLite:
|
| 145 |
+
with MEMORY_LOCK:
|
| 146 |
+
if user_id not in USER_MEMORIES:
|
| 147 |
+
state = None
|
| 148 |
+
try:
|
| 149 |
+
state = get_memory_backend().load(user_id)
|
| 150 |
+
except Exception:
|
| 151 |
+
LOGGER.exception("Durable memory load failed; starting empty memory")
|
| 152 |
+
if state:
|
| 153 |
+
memory = MemPalaceLite.from_dict(state, retrieval_mode=retrieval_mode)
|
| 154 |
+
memory.session_id = user_id
|
| 155 |
+
else:
|
| 156 |
+
memory = MemPalaceLite(
|
| 157 |
+
retrieval_mode=retrieval_mode,
|
| 158 |
+
session_id=user_id,
|
| 159 |
+
)
|
| 160 |
+
USER_MEMORIES[user_id] = memory
|
| 161 |
+
USER_MEMORY_COUNT.set(len(USER_MEMORIES))
|
| 162 |
+
return USER_MEMORIES[user_id]
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def create_app() -> FastAPI:
|
| 166 |
+
config = configure_environment_from_file()
|
| 167 |
+
app = FastAPI(
|
| 168 |
+
title="Smriti AI API",
|
| 169 |
+
version="0.3.1",
|
| 170 |
+
description="Semantic memory, knowledge graph and identity governance API.",
|
| 171 |
+
)
|
| 172 |
+
app.add_middleware(
|
| 173 |
+
CORSMiddleware,
|
| 174 |
+
allow_origins=config.cors_origins,
|
| 175 |
+
allow_credentials=True,
|
| 176 |
+
allow_methods=["*"],
|
| 177 |
+
allow_headers=["*"],
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
@app.middleware("http")
|
| 181 |
+
async def request_observability(request: Request, call_next: Callable[..., Any]) -> Response:
|
| 182 |
+
request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
|
| 183 |
+
path = request.url.path
|
| 184 |
+
start = time.perf_counter()
|
| 185 |
+
status_code = 500
|
| 186 |
+
try:
|
| 187 |
+
_enforce_api_key(request)
|
| 188 |
+
response = await call_next(request)
|
| 189 |
+
status_code = response.status_code
|
| 190 |
+
except HTTPException as exc:
|
| 191 |
+
status_code = exc.status_code
|
| 192 |
+
response = Response(
|
| 193 |
+
content=f'{{"detail":"{exc.detail}"}}',
|
| 194 |
+
status_code=exc.status_code,
|
| 195 |
+
media_type="application/json",
|
| 196 |
+
)
|
| 197 |
+
except Exception:
|
| 198 |
+
LOGGER.exception(
|
| 199 |
+
"Unhandled API request failure",
|
| 200 |
+
extra={"request_id": request_id, "path": path},
|
| 201 |
+
)
|
| 202 |
+
response = Response(
|
| 203 |
+
content='{"detail":"Internal server error"}',
|
| 204 |
+
status_code=500,
|
| 205 |
+
media_type="application/json",
|
| 206 |
+
)
|
| 207 |
+
duration = time.perf_counter() - start
|
| 208 |
+
HTTP_LATENCY.labels(request.method, path).observe(duration)
|
| 209 |
+
HTTP_REQUESTS.labels(request.method, path, str(status_code)).inc()
|
| 210 |
+
if status_code >= 500:
|
| 211 |
+
HTTP_ERRORS.labels(request.method, path).inc()
|
| 212 |
+
USER_MEMORY_COUNT.set(len(USER_MEMORIES))
|
| 213 |
+
response.headers["x-request-id"] = request_id
|
| 214 |
+
LOGGER.info(
|
| 215 |
+
"request completed request_id=%s method=%s path=%s status=%s duration_s=%.6f",
|
| 216 |
+
request_id,
|
| 217 |
+
request.method,
|
| 218 |
+
path,
|
| 219 |
+
status_code,
|
| 220 |
+
duration,
|
| 221 |
+
)
|
| 222 |
+
return response
|
| 223 |
+
|
| 224 |
+
@app.get("/health")
|
| 225 |
+
def health() -> Dict[str, Any]:
|
| 226 |
+
return {"status": "ok", "users": len(USER_MEMORIES)}
|
| 227 |
+
|
| 228 |
+
@app.get("/metrics")
|
| 229 |
+
def metrics() -> Response:
|
| 230 |
+
return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
|
| 231 |
+
|
| 232 |
+
@app.post("/chat", response_model=ChatResponse)
|
| 233 |
+
def chat(request: ChatRequest) -> ChatResponse:
|
| 234 |
+
memory = get_memory(request.user_id, retrieval_mode=request.retrieval_mode)
|
| 235 |
+
memory.retrieval_mode = request.retrieval_mode
|
| 236 |
+
memory.topic_id = request.topic_id
|
| 237 |
+
context, degraded, warnings = _safe_get_context(
|
| 238 |
+
memory,
|
| 239 |
+
query=request.message,
|
| 240 |
+
session_id=request.user_id,
|
| 241 |
+
topic_id=request.topic_id,
|
| 242 |
+
retrieval_mode=request.retrieval_mode,
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
if AGENT_FACTORY is not None:
|
| 246 |
+
agent = _build_agent(
|
| 247 |
+
AGENT_FACTORY,
|
| 248 |
+
user_id=request.user_id,
|
| 249 |
+
memory=memory,
|
| 250 |
+
topic_id=request.topic_id,
|
| 251 |
+
agent_id=request.agent_id,
|
| 252 |
+
)
|
| 253 |
+
with MEMORY_LOCK:
|
| 254 |
+
try:
|
| 255 |
+
response = agent.chat(request.message)
|
| 256 |
+
except Exception as exc:
|
| 257 |
+
LOGGER.exception("Agent factory chat failed")
|
| 258 |
+
degraded = True
|
| 259 |
+
warnings.append(f"agent_failure:{exc.__class__.__name__}")
|
| 260 |
+
response = _memory_only_response(context)
|
| 261 |
+
state = memory.to_dict()
|
| 262 |
+
else:
|
| 263 |
+
response = _memory_only_response(context)
|
| 264 |
+
with MEMORY_LOCK:
|
| 265 |
+
_safe_update_memory(
|
| 266 |
+
memory,
|
| 267 |
+
request.message,
|
| 268 |
+
response,
|
| 269 |
+
request.user_id,
|
| 270 |
+
request.topic_id,
|
| 271 |
+
warnings,
|
| 272 |
+
)
|
| 273 |
+
state = memory.to_dict()
|
| 274 |
+
_persist_if_configured(request.user_id, state, warnings)
|
| 275 |
+
|
| 276 |
+
TOKEN_USAGE.labels(request.user_id, request.agent_id).inc(
|
| 277 |
+
_count_tokens(request.message) + _count_tokens(response)
|
| 278 |
+
)
|
| 279 |
+
state["_degraded"] = degraded
|
| 280 |
+
state["_warnings"] = warnings
|
| 281 |
+
return ChatResponse(
|
| 282 |
+
user_id=request.user_id,
|
| 283 |
+
agent_id=request.agent_id,
|
| 284 |
+
topic_id=request.topic_id,
|
| 285 |
+
response=response,
|
| 286 |
+
retrieved_context=context,
|
| 287 |
+
memory=state,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
@app.post("/memory/load")
|
| 291 |
+
def load_memory(request: MemoryLoadRequest) -> Dict[str, Any]:
|
| 292 |
+
with MEMORY_LOCK:
|
| 293 |
+
if request.path:
|
| 294 |
+
memory = MemPalaceLite.load(
|
| 295 |
+
request.path,
|
| 296 |
+
retrieval_mode=request.retrieval_mode,
|
| 297 |
+
)
|
| 298 |
+
elif request.memory:
|
| 299 |
+
memory = MemPalaceLite.from_dict(
|
| 300 |
+
request.memory or {},
|
| 301 |
+
retrieval_mode=request.retrieval_mode,
|
| 302 |
+
)
|
| 303 |
+
else:
|
| 304 |
+
state = get_memory_backend().load(request.user_id)
|
| 305 |
+
if state is None:
|
| 306 |
+
raise HTTPException(status_code=404, detail="No memory found for user.")
|
| 307 |
+
memory = MemPalaceLite.from_dict(
|
| 308 |
+
state,
|
| 309 |
+
retrieval_mode=request.retrieval_mode,
|
| 310 |
+
)
|
| 311 |
+
memory.session_id = request.user_id
|
| 312 |
+
USER_MEMORIES[request.user_id] = memory
|
| 313 |
+
return memory.to_dict()
|
| 314 |
+
|
| 315 |
+
@app.post("/memory/save")
|
| 316 |
+
def save_memory(request: MemorySaveRequest) -> Dict[str, Any]:
|
| 317 |
+
memory = get_memory(request.user_id)
|
| 318 |
+
with MEMORY_LOCK:
|
| 319 |
+
if request.path:
|
| 320 |
+
memory.save(request.path)
|
| 321 |
+
state = memory.to_dict()
|
| 322 |
+
if not request.path:
|
| 323 |
+
get_memory_backend().save(request.user_id, state)
|
| 324 |
+
_observe_memory_size(request.user_id, state)
|
| 325 |
+
return state
|
| 326 |
+
|
| 327 |
+
@app.post("/memory/delete")
|
| 328 |
+
def delete_memory(request: MemoryDeleteRequest) -> Dict[str, Any]:
|
| 329 |
+
with MEMORY_LOCK:
|
| 330 |
+
existed = USER_MEMORIES.pop(request.user_id, None) is not None
|
| 331 |
+
USER_MEMORY_COUNT.set(len(USER_MEMORIES))
|
| 332 |
+
deleted_file = False
|
| 333 |
+
if request.path and os.path.exists(request.path):
|
| 334 |
+
os.remove(request.path)
|
| 335 |
+
deleted_file = True
|
| 336 |
+
deleted_backend = False
|
| 337 |
+
try:
|
| 338 |
+
deleted_backend = get_memory_backend().delete_user(request.user_id)
|
| 339 |
+
except Exception:
|
| 340 |
+
LOGGER.exception("Durable memory deletion failed")
|
| 341 |
+
try:
|
| 342 |
+
USER_MEMORY_BYTES.remove(request.user_id)
|
| 343 |
+
except Exception:
|
| 344 |
+
pass
|
| 345 |
+
return {
|
| 346 |
+
"user_id": request.user_id,
|
| 347 |
+
"deleted_memory": existed,
|
| 348 |
+
"deleted_file": deleted_file,
|
| 349 |
+
"deleted_backend": deleted_backend,
|
| 350 |
+
"remaining_users": len(USER_MEMORIES),
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
@app.post("/graph/query")
|
| 354 |
+
def graph_query(request: GraphQueryRequest) -> Dict[str, Any]:
|
| 355 |
+
memory = get_memory(request.user_id)
|
| 356 |
+
try:
|
| 357 |
+
triples = memory.knowledge_graph.query_graph(
|
| 358 |
+
request.user_id,
|
| 359 |
+
request.query_entity,
|
| 360 |
+
depth=request.depth,
|
| 361 |
+
topic_id=request.topic_id,
|
| 362 |
+
)
|
| 363 |
+
degraded = False
|
| 364 |
+
warnings: List[str] = []
|
| 365 |
+
except Exception as exc:
|
| 366 |
+
LOGGER.exception("Knowledge graph query failed")
|
| 367 |
+
triples = []
|
| 368 |
+
degraded = True
|
| 369 |
+
warnings = [f"knowledge_graph_failure:{exc.__class__.__name__}"]
|
| 370 |
+
return {
|
| 371 |
+
"user_id": request.user_id,
|
| 372 |
+
"query_entity": request.query_entity,
|
| 373 |
+
"triples": [triple.__dict__ for triple in triples],
|
| 374 |
+
"facts": memory.knowledge_graph.triples_to_text(triples),
|
| 375 |
+
"degraded": degraded,
|
| 376 |
+
"warnings": warnings,
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
return app
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def _build_agent(factory: AgentFactory, **kwargs: Any) -> Any:
|
| 383 |
+
try:
|
| 384 |
+
return factory(**kwargs)
|
| 385 |
+
except TypeError:
|
| 386 |
+
return factory(kwargs["memory"])
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def _memory_only_response(context: str) -> str:
|
| 390 |
+
if context:
|
| 391 |
+
bullets = []
|
| 392 |
+
for line in context.splitlines():
|
| 393 |
+
cleaned = line.strip()
|
| 394 |
+
if cleaned.startswith("* "):
|
| 395 |
+
bullets.append(cleaned[2:].strip())
|
| 396 |
+
if bullets:
|
| 397 |
+
rendered = "\n".join(f"- {fact}" for fact in bullets[:5])
|
| 398 |
+
return f"Memory updated. I found relevant context:\n{rendered}"
|
| 399 |
+
return "Memory updated. Relevant context is available for the configured model."
|
| 400 |
+
return "Memory updated. No prior relevant context was found."
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def _safe_get_context(
|
| 404 |
+
memory: MemPalaceLite,
|
| 405 |
+
query: str,
|
| 406 |
+
session_id: str,
|
| 407 |
+
topic_id: str,
|
| 408 |
+
retrieval_mode: str,
|
| 409 |
+
) -> tuple[str, bool, List[str]]:
|
| 410 |
+
warnings: List[str] = []
|
| 411 |
+
with _observe_retrieval(retrieval_mode):
|
| 412 |
+
try:
|
| 413 |
+
return (
|
| 414 |
+
memory.get_context(
|
| 415 |
+
query=query,
|
| 416 |
+
session_id=session_id,
|
| 417 |
+
topic_id=topic_id,
|
| 418 |
+
),
|
| 419 |
+
False,
|
| 420 |
+
warnings,
|
| 421 |
+
)
|
| 422 |
+
except Exception as exc:
|
| 423 |
+
LOGGER.exception("Primary retrieval failed; degrading to TF-IDF/no-graph context")
|
| 424 |
+
warnings.append(f"primary_retrieval_failure:{exc.__class__.__name__}")
|
| 425 |
+
|
| 426 |
+
original_mode = memory.retrieval_mode
|
| 427 |
+
try:
|
| 428 |
+
memory.retrieval_mode = "tfidf"
|
| 429 |
+
with _observe_retrieval("tfidf"):
|
| 430 |
+
context = memory.get_context(
|
| 431 |
+
query=query,
|
| 432 |
+
session_id=session_id,
|
| 433 |
+
topic_id=topic_id,
|
| 434 |
+
include_graph=False,
|
| 435 |
+
)
|
| 436 |
+
warnings.append("degraded_to_tfidf")
|
| 437 |
+
return context, True, warnings
|
| 438 |
+
except Exception as exc:
|
| 439 |
+
LOGGER.exception("Fallback TF-IDF retrieval failed")
|
| 440 |
+
warnings.append(f"fallback_retrieval_failure:{exc.__class__.__name__}")
|
| 441 |
+
return "", True, warnings
|
| 442 |
+
finally:
|
| 443 |
+
memory.retrieval_mode = original_mode
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def _safe_update_memory(
|
| 447 |
+
memory: MemPalaceLite,
|
| 448 |
+
user_message: str,
|
| 449 |
+
response: str,
|
| 450 |
+
session_id: str,
|
| 451 |
+
topic_id: str,
|
| 452 |
+
warnings: List[str],
|
| 453 |
+
) -> None:
|
| 454 |
+
try:
|
| 455 |
+
for fact in memory.extract_facts("", user_input=user_message):
|
| 456 |
+
try:
|
| 457 |
+
memory.add_fact(fact, session_id=session_id, topic_id=topic_id)
|
| 458 |
+
except Exception as exc:
|
| 459 |
+
LOGGER.exception("Fact storage failed")
|
| 460 |
+
warnings.append(f"fact_storage_failure:{exc.__class__.__name__}")
|
| 461 |
+
memory.add_to_history("User: " + user_message, "user_input")
|
| 462 |
+
memory.add_to_history("Assistant: " + response[:200], "assistant_output")
|
| 463 |
+
except Exception as exc:
|
| 464 |
+
LOGGER.exception("Memory update failed")
|
| 465 |
+
warnings.append(f"memory_update_failure:{exc.__class__.__name__}")
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def _persist_if_configured(user_id: str, state: Dict[str, Any], warnings: List[str]) -> None:
|
| 469 |
+
if os.getenv("SMRITI_AUTOSAVE", "0").lower() not in {"1", "true", "yes"}:
|
| 470 |
+
_observe_memory_size(user_id, state)
|
| 471 |
+
return
|
| 472 |
+
try:
|
| 473 |
+
get_memory_backend().save(user_id, state)
|
| 474 |
+
_observe_memory_size(user_id, state)
|
| 475 |
+
except Exception as exc:
|
| 476 |
+
LOGGER.exception("Durable memory autosave failed")
|
| 477 |
+
warnings.append(f"autosave_failure:{exc.__class__.__name__}")
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def _observe_memory_size(user_id: str, state: Dict[str, Any]) -> None:
|
| 481 |
+
try:
|
| 482 |
+
USER_MEMORY_BYTES.labels(user_id).set(len(json.dumps(state)))
|
| 483 |
+
except Exception:
|
| 484 |
+
pass
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
@contextmanager
|
| 488 |
+
def _observe_retrieval(retrieval_mode: str) -> Iterator[None]:
|
| 489 |
+
start = time.perf_counter()
|
| 490 |
+
try:
|
| 491 |
+
yield
|
| 492 |
+
finally:
|
| 493 |
+
RETRIEVAL_LATENCY.labels(retrieval_mode).observe(time.perf_counter() - start)
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def _count_tokens(text: str) -> int:
|
| 497 |
+
return max(1, len(text.split())) if text else 0
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def _enforce_api_key(request: Request) -> None:
|
| 501 |
+
expected = os.getenv("SMRITI_API_KEY")
|
| 502 |
+
if not expected:
|
| 503 |
+
return
|
| 504 |
+
if request.url.path in {"/health", "/metrics", "/docs", "/openapi.json"}:
|
| 505 |
+
return
|
| 506 |
+
supplied = request.headers.get("x-api-key")
|
| 507 |
+
if supplied != expected:
|
| 508 |
+
raise HTTPException(status_code=401, detail="Invalid or missing API key.")
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
app = create_app()
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def main(argv: Optional[List[str]] = None) -> None:
|
| 515 |
+
"""Run the API with `python -m smriti.api` or the `smriti-api` entry point."""
|
| 516 |
+
|
| 517 |
+
import argparse
|
| 518 |
+
import uvicorn
|
| 519 |
+
|
| 520 |
+
parser = argparse.ArgumentParser(description="Run the Smriti AI FastAPI service.")
|
| 521 |
+
parser.add_argument("--config", help="Path to config.yaml.")
|
| 522 |
+
parser.add_argument("--host", help="Bind host. Defaults to config or SMRITI_HOST.")
|
| 523 |
+
parser.add_argument("--port", type=int, help="Bind port. Defaults to config or SMRITI_PORT.")
|
| 524 |
+
parser.add_argument("--reload", action="store_true", help="Enable Uvicorn reload mode.")
|
| 525 |
+
args = parser.parse_args(argv)
|
| 526 |
+
if args.config:
|
| 527 |
+
os.environ["SMRITI_CONFIG_PATH"] = args.config
|
| 528 |
+
config = load_config()
|
| 529 |
+
uvicorn.run(
|
| 530 |
+
"smriti.api:app" if args.reload else app,
|
| 531 |
+
host=args.host or os.getenv("SMRITI_HOST", config.host),
|
| 532 |
+
port=args.port or int(os.getenv("SMRITI_PORT", config.port)),
|
| 533 |
+
reload=args.reload,
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
if __name__ == "__main__":
|
| 538 |
+
main()
|
smriti_vendor/smriti/backends.py
ADDED
|
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Durable memory backends for Smriti AI.
|
| 2 |
+
|
| 3 |
+
Backends persist complete user memory blobs and also expose a minimal entry API
|
| 4 |
+
for tools that want to store/retrieve lightweight facts without instantiating the
|
| 5 |
+
full runtime. Optional encryption is applied at the blob boundary so JSON, SQL,
|
| 6 |
+
Redis, and Postgres stores share the same privacy behavior.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import base64
|
| 12 |
+
import hashlib
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import re
|
| 16 |
+
import sqlite3
|
| 17 |
+
import time
|
| 18 |
+
from abc import ABC, abstractmethod
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any, Dict, List, Optional
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MemoryBackend(ABC):
|
| 24 |
+
"""Abstract persistence contract for user-isolated Smriti AI memory."""
|
| 25 |
+
|
| 26 |
+
@abstractmethod
|
| 27 |
+
def load(self, user_id: str) -> Optional[Dict[str, Any]]:
|
| 28 |
+
"""Load a complete memory state for one user, or None if absent."""
|
| 29 |
+
|
| 30 |
+
@abstractmethod
|
| 31 |
+
def save(self, user_id: str, memory: Dict[str, Any]) -> None:
|
| 32 |
+
"""Persist a complete memory state for one user."""
|
| 33 |
+
|
| 34 |
+
@abstractmethod
|
| 35 |
+
def add_entry(
|
| 36 |
+
self,
|
| 37 |
+
user_id: str,
|
| 38 |
+
session_id: str,
|
| 39 |
+
topic_id: str,
|
| 40 |
+
text: str,
|
| 41 |
+
metadata: Optional[Dict[str, Any]] = None,
|
| 42 |
+
) -> None:
|
| 43 |
+
"""Persist one lightweight fact/entry for a user/session/topic."""
|
| 44 |
+
|
| 45 |
+
@abstractmethod
|
| 46 |
+
def retrieve(
|
| 47 |
+
self,
|
| 48 |
+
user_id: str,
|
| 49 |
+
session_id: Optional[str] = None,
|
| 50 |
+
topic_id: Optional[str] = None,
|
| 51 |
+
query: str = "",
|
| 52 |
+
k: int = 5,
|
| 53 |
+
) -> List[Dict[str, Any]]:
|
| 54 |
+
"""Retrieve lightweight entries scoped to a user/session/topic."""
|
| 55 |
+
|
| 56 |
+
@abstractmethod
|
| 57 |
+
def delete_user(self, user_id: str) -> bool:
|
| 58 |
+
"""Delete all memory owned by one user. Return whether anything existed."""
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class MemoryCipher:
|
| 62 |
+
"""Optional symmetric encryption wrapper using Fernet when configured."""
|
| 63 |
+
|
| 64 |
+
def __init__(self, secret: Optional[str] = None):
|
| 65 |
+
self.secret = secret or os.getenv("SMRITI_MEMORY_KEY")
|
| 66 |
+
self._fernet = None
|
| 67 |
+
if self.secret:
|
| 68 |
+
try:
|
| 69 |
+
from cryptography.fernet import Fernet
|
| 70 |
+
except Exception as exc: # pragma: no cover - depends on optional install.
|
| 71 |
+
raise RuntimeError(
|
| 72 |
+
"SMRITI_MEMORY_KEY is set, but cryptography is not installed. "
|
| 73 |
+
"Install smriti-ai[security] or smriti-ai[full]."
|
| 74 |
+
) from exc
|
| 75 |
+
self._fernet = Fernet(_fernet_key(self.secret))
|
| 76 |
+
|
| 77 |
+
@property
|
| 78 |
+
def enabled(self) -> bool:
|
| 79 |
+
return self._fernet is not None
|
| 80 |
+
|
| 81 |
+
def wrap(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
| 82 |
+
if not self._fernet:
|
| 83 |
+
return {"encrypted": False, "payload": payload}
|
| 84 |
+
raw = json.dumps(payload, sort_keys=True).encode("utf-8")
|
| 85 |
+
return {
|
| 86 |
+
"encrypted": True,
|
| 87 |
+
"algorithm": "fernet-sha256-derived-key",
|
| 88 |
+
"payload": self._fernet.encrypt(raw).decode("utf-8"),
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
def unwrap(self, wrapper: Dict[str, Any]) -> Dict[str, Any]:
|
| 92 |
+
if not wrapper.get("encrypted"):
|
| 93 |
+
return dict(wrapper.get("payload", {}))
|
| 94 |
+
if not self._fernet:
|
| 95 |
+
raise RuntimeError("Memory blob is encrypted but SMRITI_MEMORY_KEY is not configured.")
|
| 96 |
+
decrypted = self._fernet.decrypt(wrapper["payload"].encode("utf-8"))
|
| 97 |
+
return json.loads(decrypted.decode("utf-8"))
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def build_backend(kind: Optional[str] = None, **kwargs: Any) -> MemoryBackend:
|
| 101 |
+
"""Construct a backend from an explicit kind or SMRITI_MEMORY_BACKEND."""
|
| 102 |
+
|
| 103 |
+
selected = (kind or os.getenv("SMRITI_MEMORY_BACKEND") or "json").lower()
|
| 104 |
+
if selected == "json":
|
| 105 |
+
return JsonBackend(root=kwargs.get("root") or os.getenv("SMRITI_MEMORY_DIR", "data/memory"))
|
| 106 |
+
if selected == "sqlite":
|
| 107 |
+
return SqliteBackend(path=kwargs.get("path") or os.getenv("SMRITI_SQLITE_PATH", "data/smriti_memory.sqlite3"))
|
| 108 |
+
if selected == "redis":
|
| 109 |
+
return RedisBackend(url=kwargs.get("url") or os.getenv("SMRITI_REDIS_URL", "redis://localhost:6379/0"))
|
| 110 |
+
if selected in {"postgres", "postgresql"}:
|
| 111 |
+
return PostgresBackend(dsn=kwargs.get("dsn") or os.getenv("SMRITI_POSTGRES_DSN", ""))
|
| 112 |
+
raise ValueError("SMRITI_MEMORY_BACKEND must be one of: json, sqlite, redis, postgres.")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class JsonBackend(MemoryBackend):
|
| 116 |
+
"""File-per-user JSON backend. This preserves the original local behavior."""
|
| 117 |
+
|
| 118 |
+
def __init__(self, root: str | Path = "data/memory", cipher: Optional[MemoryCipher] = None):
|
| 119 |
+
self.root = Path(root)
|
| 120 |
+
self.cipher = cipher or MemoryCipher()
|
| 121 |
+
|
| 122 |
+
def load(self, user_id: str) -> Optional[Dict[str, Any]]:
|
| 123 |
+
path = self._path(user_id)
|
| 124 |
+
if not path.exists():
|
| 125 |
+
return None
|
| 126 |
+
return self.cipher.unwrap(json.loads(path.read_text(encoding="utf-8")))
|
| 127 |
+
|
| 128 |
+
def save(self, user_id: str, memory: Dict[str, Any]) -> None:
|
| 129 |
+
self.root.mkdir(parents=True, exist_ok=True)
|
| 130 |
+
self._path(user_id).write_text(
|
| 131 |
+
json.dumps(self.cipher.wrap(memory), indent=2),
|
| 132 |
+
encoding="utf-8",
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
def add_entry(
|
| 136 |
+
self,
|
| 137 |
+
user_id: str,
|
| 138 |
+
session_id: str,
|
| 139 |
+
topic_id: str,
|
| 140 |
+
text: str,
|
| 141 |
+
metadata: Optional[Dict[str, Any]] = None,
|
| 142 |
+
) -> None:
|
| 143 |
+
state = self.load(user_id) or {"backend_entries": []}
|
| 144 |
+
state.setdefault("backend_entries", []).append(_entry(session_id, topic_id, text, metadata))
|
| 145 |
+
self.save(user_id, state)
|
| 146 |
+
|
| 147 |
+
def retrieve(
|
| 148 |
+
self,
|
| 149 |
+
user_id: str,
|
| 150 |
+
session_id: Optional[str] = None,
|
| 151 |
+
topic_id: Optional[str] = None,
|
| 152 |
+
query: str = "",
|
| 153 |
+
k: int = 5,
|
| 154 |
+
) -> List[Dict[str, Any]]:
|
| 155 |
+
state = self.load(user_id) or {}
|
| 156 |
+
return _rank_entries(state.get("backend_entries", []), session_id, topic_id, query, k)
|
| 157 |
+
|
| 158 |
+
def delete_user(self, user_id: str) -> bool:
|
| 159 |
+
path = self._path(user_id)
|
| 160 |
+
existed = path.exists()
|
| 161 |
+
if existed:
|
| 162 |
+
path.unlink()
|
| 163 |
+
return existed
|
| 164 |
+
|
| 165 |
+
def _path(self, user_id: str) -> Path:
|
| 166 |
+
return self.root / f"{_safe_id(user_id)}.json"
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class SqliteBackend(MemoryBackend):
|
| 170 |
+
"""SQLite backend for local durable multi-user memory."""
|
| 171 |
+
|
| 172 |
+
def __init__(self, path: str | Path = "data/smriti_memory.sqlite3", cipher: Optional[MemoryCipher] = None):
|
| 173 |
+
self.path = Path(path)
|
| 174 |
+
self.cipher = cipher or MemoryCipher()
|
| 175 |
+
self._init_schema()
|
| 176 |
+
|
| 177 |
+
def load(self, user_id: str) -> Optional[Dict[str, Any]]:
|
| 178 |
+
with self._connect() as conn:
|
| 179 |
+
row = conn.execute("SELECT payload FROM user_memory WHERE user_id = ?", (user_id,)).fetchone()
|
| 180 |
+
if not row:
|
| 181 |
+
return None
|
| 182 |
+
return self.cipher.unwrap(json.loads(row[0]))
|
| 183 |
+
|
| 184 |
+
def save(self, user_id: str, memory: Dict[str, Any]) -> None:
|
| 185 |
+
payload = json.dumps(self.cipher.wrap(memory))
|
| 186 |
+
with self._connect() as conn:
|
| 187 |
+
conn.execute(
|
| 188 |
+
"""
|
| 189 |
+
INSERT INTO user_memory(user_id, payload, updated_at)
|
| 190 |
+
VALUES(?, ?, ?)
|
| 191 |
+
ON CONFLICT(user_id) DO UPDATE SET payload=excluded.payload, updated_at=excluded.updated_at
|
| 192 |
+
""",
|
| 193 |
+
(user_id, payload, time.time()),
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
def add_entry(
|
| 197 |
+
self,
|
| 198 |
+
user_id: str,
|
| 199 |
+
session_id: str,
|
| 200 |
+
topic_id: str,
|
| 201 |
+
text: str,
|
| 202 |
+
metadata: Optional[Dict[str, Any]] = None,
|
| 203 |
+
) -> None:
|
| 204 |
+
with self._connect() as conn:
|
| 205 |
+
conn.execute(
|
| 206 |
+
"""
|
| 207 |
+
INSERT INTO memory_entries(user_id, session_id, topic_id, text, metadata, created_at)
|
| 208 |
+
VALUES(?, ?, ?, ?, ?, ?)
|
| 209 |
+
""",
|
| 210 |
+
(user_id, session_id, topic_id, text, json.dumps(metadata or {}), time.time()),
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
def retrieve(
|
| 214 |
+
self,
|
| 215 |
+
user_id: str,
|
| 216 |
+
session_id: Optional[str] = None,
|
| 217 |
+
topic_id: Optional[str] = None,
|
| 218 |
+
query: str = "",
|
| 219 |
+
k: int = 5,
|
| 220 |
+
) -> List[Dict[str, Any]]:
|
| 221 |
+
clauses = ["user_id = ?"]
|
| 222 |
+
params: List[Any] = [user_id]
|
| 223 |
+
if session_id:
|
| 224 |
+
clauses.append("session_id = ?")
|
| 225 |
+
params.append(session_id)
|
| 226 |
+
if topic_id:
|
| 227 |
+
clauses.append("topic_id = ?")
|
| 228 |
+
params.append(topic_id)
|
| 229 |
+
params.append(max(1, k * 5))
|
| 230 |
+
sql = f"""
|
| 231 |
+
SELECT session_id, topic_id, text, metadata, created_at
|
| 232 |
+
FROM memory_entries
|
| 233 |
+
WHERE {' AND '.join(clauses)}
|
| 234 |
+
ORDER BY created_at DESC
|
| 235 |
+
LIMIT ?
|
| 236 |
+
"""
|
| 237 |
+
with self._connect() as conn:
|
| 238 |
+
rows = conn.execute(sql, params).fetchall()
|
| 239 |
+
entries = [
|
| 240 |
+
{
|
| 241 |
+
"session_id": row[0],
|
| 242 |
+
"topic_id": row[1],
|
| 243 |
+
"text": row[2],
|
| 244 |
+
"metadata": json.loads(row[3] or "{}"),
|
| 245 |
+
"created_at": row[4],
|
| 246 |
+
}
|
| 247 |
+
for row in rows
|
| 248 |
+
]
|
| 249 |
+
return _rank_entries(entries, session_id, topic_id, query, k)
|
| 250 |
+
|
| 251 |
+
def delete_user(self, user_id: str) -> bool:
|
| 252 |
+
with self._connect() as conn:
|
| 253 |
+
before = conn.total_changes
|
| 254 |
+
conn.execute("DELETE FROM user_memory WHERE user_id = ?", (user_id,))
|
| 255 |
+
conn.execute("DELETE FROM memory_entries WHERE user_id = ?", (user_id,))
|
| 256 |
+
return conn.total_changes > before
|
| 257 |
+
|
| 258 |
+
def _connect(self) -> sqlite3.Connection:
|
| 259 |
+
self.path.parent.mkdir(parents=True, exist_ok=True)
|
| 260 |
+
return sqlite3.connect(self.path)
|
| 261 |
+
|
| 262 |
+
def _init_schema(self) -> None:
|
| 263 |
+
with self._connect() as conn:
|
| 264 |
+
conn.execute(
|
| 265 |
+
"""
|
| 266 |
+
CREATE TABLE IF NOT EXISTS user_memory(
|
| 267 |
+
user_id TEXT PRIMARY KEY,
|
| 268 |
+
payload TEXT NOT NULL,
|
| 269 |
+
updated_at REAL NOT NULL
|
| 270 |
+
)
|
| 271 |
+
"""
|
| 272 |
+
)
|
| 273 |
+
conn.execute(
|
| 274 |
+
"""
|
| 275 |
+
CREATE TABLE IF NOT EXISTS memory_entries(
|
| 276 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 277 |
+
user_id TEXT NOT NULL,
|
| 278 |
+
session_id TEXT NOT NULL,
|
| 279 |
+
topic_id TEXT NOT NULL,
|
| 280 |
+
text TEXT NOT NULL,
|
| 281 |
+
metadata TEXT NOT NULL,
|
| 282 |
+
created_at REAL NOT NULL
|
| 283 |
+
)
|
| 284 |
+
"""
|
| 285 |
+
)
|
| 286 |
+
conn.execute(
|
| 287 |
+
"CREATE INDEX IF NOT EXISTS idx_entries_user_session_topic ON memory_entries(user_id, session_id, topic_id, created_at)"
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class RedisBackend(MemoryBackend): # pragma: no cover - requires external Redis service.
|
| 292 |
+
"""Redis backend using string payloads and per-user entry lists."""
|
| 293 |
+
|
| 294 |
+
def __init__(self, url: str = "redis://localhost:6379/0", cipher: Optional[MemoryCipher] = None):
|
| 295 |
+
try:
|
| 296 |
+
import redis
|
| 297 |
+
except Exception as exc: # pragma: no cover - optional dependency.
|
| 298 |
+
raise RuntimeError("Install redis to use RedisBackend: pip install smriti-ai[backends]") from exc
|
| 299 |
+
self.client = redis.Redis.from_url(url, decode_responses=True)
|
| 300 |
+
self.cipher = cipher or MemoryCipher()
|
| 301 |
+
|
| 302 |
+
def load(self, user_id: str) -> Optional[Dict[str, Any]]:
|
| 303 |
+
raw = self.client.get(self._payload_key(user_id))
|
| 304 |
+
if not raw:
|
| 305 |
+
return None
|
| 306 |
+
return self.cipher.unwrap(json.loads(raw))
|
| 307 |
+
|
| 308 |
+
def save(self, user_id: str, memory: Dict[str, Any]) -> None:
|
| 309 |
+
self.client.set(self._payload_key(user_id), json.dumps(self.cipher.wrap(memory)))
|
| 310 |
+
|
| 311 |
+
def add_entry(self, user_id: str, session_id: str, topic_id: str, text: str, metadata: Optional[Dict[str, Any]] = None) -> None:
|
| 312 |
+
self.client.lpush(self._entries_key(user_id), json.dumps(_entry(session_id, topic_id, text, metadata)))
|
| 313 |
+
|
| 314 |
+
def retrieve(self, user_id: str, session_id: Optional[str] = None, topic_id: Optional[str] = None, query: str = "", k: int = 5) -> List[Dict[str, Any]]:
|
| 315 |
+
raw_entries = self.client.lrange(self._entries_key(user_id), 0, max(0, k * 5 - 1))
|
| 316 |
+
entries = [json.loads(item) for item in raw_entries]
|
| 317 |
+
return _rank_entries(entries, session_id, topic_id, query, k)
|
| 318 |
+
|
| 319 |
+
def delete_user(self, user_id: str) -> bool:
|
| 320 |
+
return bool(self.client.delete(self._payload_key(user_id), self._entries_key(user_id)))
|
| 321 |
+
|
| 322 |
+
def _payload_key(self, user_id: str) -> str:
|
| 323 |
+
return f"smriti:user:{_safe_id(user_id)}:payload"
|
| 324 |
+
|
| 325 |
+
def _entries_key(self, user_id: str) -> str:
|
| 326 |
+
return f"smriti:user:{_safe_id(user_id)}:entries"
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class PostgresBackend(MemoryBackend): # pragma: no cover - requires external Postgres service.
|
| 330 |
+
"""Postgres backend using psycopg2 and indexed user/session/topic tables."""
|
| 331 |
+
|
| 332 |
+
def __init__(self, dsn: str, cipher: Optional[MemoryCipher] = None):
|
| 333 |
+
if not dsn:
|
| 334 |
+
raise ValueError("SMRITI_POSTGRES_DSN is required for PostgresBackend.")
|
| 335 |
+
try:
|
| 336 |
+
import psycopg2
|
| 337 |
+
except Exception as exc: # pragma: no cover - optional dependency.
|
| 338 |
+
raise RuntimeError("Install psycopg2-binary to use PostgresBackend: pip install smriti-ai[backends]") from exc
|
| 339 |
+
self._psycopg2 = psycopg2
|
| 340 |
+
self.dsn = dsn
|
| 341 |
+
self.cipher = cipher or MemoryCipher()
|
| 342 |
+
self._init_schema()
|
| 343 |
+
|
| 344 |
+
def load(self, user_id: str) -> Optional[Dict[str, Any]]:
|
| 345 |
+
with self._connect() as conn, conn.cursor() as cur:
|
| 346 |
+
cur.execute("SELECT payload FROM user_memory WHERE user_id = %s", (user_id,))
|
| 347 |
+
row = cur.fetchone()
|
| 348 |
+
if not row:
|
| 349 |
+
return None
|
| 350 |
+
return self.cipher.unwrap(row[0] if isinstance(row[0], dict) else json.loads(row[0]))
|
| 351 |
+
|
| 352 |
+
def save(self, user_id: str, memory: Dict[str, Any]) -> None:
|
| 353 |
+
payload = json.dumps(self.cipher.wrap(memory))
|
| 354 |
+
with self._connect() as conn, conn.cursor() as cur:
|
| 355 |
+
cur.execute(
|
| 356 |
+
"""
|
| 357 |
+
INSERT INTO user_memory(user_id, payload, updated_at)
|
| 358 |
+
VALUES(%s, %s::jsonb, NOW())
|
| 359 |
+
ON CONFLICT(user_id) DO UPDATE SET payload=excluded.payload, updated_at=excluded.updated_at
|
| 360 |
+
""",
|
| 361 |
+
(user_id, payload),
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
def add_entry(self, user_id: str, session_id: str, topic_id: str, text: str, metadata: Optional[Dict[str, Any]] = None) -> None:
|
| 365 |
+
with self._connect() as conn, conn.cursor() as cur:
|
| 366 |
+
cur.execute(
|
| 367 |
+
"""
|
| 368 |
+
INSERT INTO memory_entries(user_id, session_id, topic_id, text, metadata)
|
| 369 |
+
VALUES(%s, %s, %s, %s, %s::jsonb)
|
| 370 |
+
""",
|
| 371 |
+
(user_id, session_id, topic_id, text, json.dumps(metadata or {})),
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
def retrieve(self, user_id: str, session_id: Optional[str] = None, topic_id: Optional[str] = None, query: str = "", k: int = 5) -> List[Dict[str, Any]]:
|
| 375 |
+
clauses = ["user_id = %s"]
|
| 376 |
+
params: List[Any] = [user_id]
|
| 377 |
+
if session_id:
|
| 378 |
+
clauses.append("session_id = %s")
|
| 379 |
+
params.append(session_id)
|
| 380 |
+
if topic_id:
|
| 381 |
+
clauses.append("topic_id = %s")
|
| 382 |
+
params.append(topic_id)
|
| 383 |
+
params.append(max(1, k * 5))
|
| 384 |
+
sql = f"""
|
| 385 |
+
SELECT session_id, topic_id, text, metadata, EXTRACT(EPOCH FROM created_at)
|
| 386 |
+
FROM memory_entries
|
| 387 |
+
WHERE {' AND '.join(clauses)}
|
| 388 |
+
ORDER BY created_at DESC
|
| 389 |
+
LIMIT %s
|
| 390 |
+
"""
|
| 391 |
+
with self._connect() as conn, conn.cursor() as cur:
|
| 392 |
+
cur.execute(sql, params)
|
| 393 |
+
rows = cur.fetchall()
|
| 394 |
+
entries = [
|
| 395 |
+
{
|
| 396 |
+
"session_id": row[0],
|
| 397 |
+
"topic_id": row[1],
|
| 398 |
+
"text": row[2],
|
| 399 |
+
"metadata": row[3] or {},
|
| 400 |
+
"created_at": float(row[4]),
|
| 401 |
+
}
|
| 402 |
+
for row in rows
|
| 403 |
+
]
|
| 404 |
+
return _rank_entries(entries, session_id, topic_id, query, k)
|
| 405 |
+
|
| 406 |
+
def delete_user(self, user_id: str) -> bool:
|
| 407 |
+
with self._connect() as conn, conn.cursor() as cur:
|
| 408 |
+
cur.execute("DELETE FROM user_memory WHERE user_id = %s", (user_id,))
|
| 409 |
+
memory_deleted = cur.rowcount
|
| 410 |
+
cur.execute("DELETE FROM memory_entries WHERE user_id = %s", (user_id,))
|
| 411 |
+
entries_deleted = cur.rowcount
|
| 412 |
+
return bool(memory_deleted or entries_deleted)
|
| 413 |
+
|
| 414 |
+
def _connect(self):
|
| 415 |
+
return self._psycopg2.connect(self.dsn)
|
| 416 |
+
|
| 417 |
+
def _init_schema(self) -> None:
|
| 418 |
+
with self._connect() as conn, conn.cursor() as cur:
|
| 419 |
+
cur.execute(
|
| 420 |
+
"""
|
| 421 |
+
CREATE TABLE IF NOT EXISTS user_memory(
|
| 422 |
+
user_id TEXT PRIMARY KEY,
|
| 423 |
+
payload JSONB NOT NULL,
|
| 424 |
+
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
| 425 |
+
)
|
| 426 |
+
"""
|
| 427 |
+
)
|
| 428 |
+
cur.execute(
|
| 429 |
+
"""
|
| 430 |
+
CREATE TABLE IF NOT EXISTS memory_entries(
|
| 431 |
+
id BIGSERIAL PRIMARY KEY,
|
| 432 |
+
user_id TEXT NOT NULL,
|
| 433 |
+
session_id TEXT NOT NULL,
|
| 434 |
+
topic_id TEXT NOT NULL,
|
| 435 |
+
text TEXT NOT NULL,
|
| 436 |
+
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
| 437 |
+
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
| 438 |
+
)
|
| 439 |
+
"""
|
| 440 |
+
)
|
| 441 |
+
cur.execute(
|
| 442 |
+
"CREATE INDEX IF NOT EXISTS idx_smriti_entries_user_session_topic ON memory_entries(user_id, session_id, topic_id, created_at DESC)"
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def _fernet_key(secret: str) -> bytes:
|
| 447 |
+
raw = secret.encode("utf-8")
|
| 448 |
+
try:
|
| 449 |
+
base64.urlsafe_b64decode(raw)
|
| 450 |
+
if len(raw) == 44:
|
| 451 |
+
return raw
|
| 452 |
+
except Exception:
|
| 453 |
+
pass
|
| 454 |
+
return base64.urlsafe_b64encode(hashlib.sha256(raw).digest())
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def _safe_id(value: str) -> str:
|
| 458 |
+
return re.sub(r"[^a-zA-Z0-9_.-]+", "_", value.strip()) or "default"
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def _entry(session_id: str, topic_id: str, text: str, metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
| 462 |
+
return {
|
| 463 |
+
"session_id": session_id,
|
| 464 |
+
"topic_id": topic_id,
|
| 465 |
+
"text": text,
|
| 466 |
+
"metadata": metadata or {},
|
| 467 |
+
"created_at": time.time(),
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def _rank_entries(
|
| 472 |
+
entries: List[Dict[str, Any]],
|
| 473 |
+
session_id: Optional[str],
|
| 474 |
+
topic_id: Optional[str],
|
| 475 |
+
query: str,
|
| 476 |
+
k: int,
|
| 477 |
+
) -> List[Dict[str, Any]]:
|
| 478 |
+
scoped = [
|
| 479 |
+
entry
|
| 480 |
+
for entry in entries
|
| 481 |
+
if (not session_id or entry.get("session_id") == session_id)
|
| 482 |
+
and (not topic_id or entry.get("topic_id") == topic_id)
|
| 483 |
+
]
|
| 484 |
+
if not query.strip():
|
| 485 |
+
return sorted(scoped, key=lambda item: item.get("created_at", 0), reverse=True)[:k]
|
| 486 |
+
q_terms = set(re.findall(r"[a-z0-9']+", query.lower()))
|
| 487 |
+
scored = []
|
| 488 |
+
for entry in scoped:
|
| 489 |
+
terms = set(re.findall(r"[a-z0-9']+", entry.get("text", "").lower()))
|
| 490 |
+
overlap = len(q_terms & terms) / max(1, len(q_terms | terms))
|
| 491 |
+
recency = entry.get("created_at", 0)
|
| 492 |
+
scored.append((overlap, recency, entry))
|
| 493 |
+
scored.sort(key=lambda item: (item[0], item[1]), reverse=True)
|
| 494 |
+
return [entry for _, _, entry in scored[:k]]
|