luciferai-devil commited on
Commit
316b3f1
·
verified ·
1 Parent(s): 2544d6e

Deploy Smriti AI Hugging Face handler

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +187 -0
  2. config.json +11 -0
  3. examples/request_delete.json +6 -0
  4. examples/request_distractor.json +8 -0
  5. examples/request_memory_inject.json +11 -0
  6. examples/request_recall.json +11 -0
  7. handler.py +647 -0
  8. requirements.txt +19 -0
  9. smriti_endpoint_config.yaml +35 -0
  10. smriti_vendor/mempalace/__init__.py +3 -0
  11. smriti_vendor/mempalace/__pycache__/__init__.cpython-310.pyc +0 -0
  12. smriti_vendor/mempalace/__pycache__/agent.cpython-310.pyc +0 -0
  13. smriti_vendor/mempalace/__pycache__/api.cpython-310.pyc +0 -0
  14. smriti_vendor/mempalace/__pycache__/cli.cpython-310.pyc +0 -0
  15. smriti_vendor/mempalace/__pycache__/core.cpython-310.pyc +0 -0
  16. smriti_vendor/mempalace/__pycache__/gifp.cpython-310.pyc +0 -0
  17. smriti_vendor/mempalace/__pycache__/identity_fingerprint.cpython-310.pyc +0 -0
  18. smriti_vendor/mempalace/__pycache__/knowledge_graph.cpython-310.pyc +0 -0
  19. smriti_vendor/mempalace/__pycache__/macp.cpython-310.pyc +0 -0
  20. smriti_vendor/mempalace/__pycache__/mem_palace.cpython-310.pyc +0 -0
  21. smriti_vendor/mempalace/__pycache__/semantic_memory.cpython-310.pyc +0 -0
  22. smriti_vendor/mempalace/agent.py +3 -0
  23. smriti_vendor/mempalace/api.py +3 -0
  24. smriti_vendor/mempalace/cli.py +3 -0
  25. smriti_vendor/mempalace/core.py +3 -0
  26. smriti_vendor/mempalace/gifp.py +3 -0
  27. smriti_vendor/mempalace/identity_fingerprint.py +3 -0
  28. smriti_vendor/mempalace/knowledge_graph.py +3 -0
  29. smriti_vendor/mempalace/macp.py +3 -0
  30. smriti_vendor/mempalace/mem_palace.py +3 -0
  31. smriti_vendor/mempalace/semantic_memory.py +3 -0
  32. smriti_vendor/smriti/__init__.py +115 -0
  33. smriti_vendor/smriti/__main__.py +7 -0
  34. smriti_vendor/smriti/__pycache__/__init__.cpython-310.pyc +0 -0
  35. smriti_vendor/smriti/__pycache__/__main__.cpython-310.pyc +0 -0
  36. smriti_vendor/smriti/__pycache__/agent.cpython-310.pyc +0 -0
  37. smriti_vendor/smriti/__pycache__/api.cpython-310.pyc +0 -0
  38. smriti_vendor/smriti/__pycache__/backends.cpython-310.pyc +0 -0
  39. smriti_vendor/smriti/__pycache__/cli.cpython-310.pyc +0 -0
  40. smriti_vendor/smriti/__pycache__/config.cpython-310.pyc +0 -0
  41. smriti_vendor/smriti/__pycache__/core.cpython-310.pyc +0 -0
  42. smriti_vendor/smriti/__pycache__/gifp.cpython-310.pyc +0 -0
  43. smriti_vendor/smriti/__pycache__/identity_fingerprint.cpython-310.pyc +0 -0
  44. smriti_vendor/smriti/__pycache__/knowledge_graph.cpython-310.pyc +0 -0
  45. smriti_vendor/smriti/__pycache__/macp.cpython-310.pyc +0 -0
  46. smriti_vendor/smriti/__pycache__/mem_palace.cpython-310.pyc +0 -0
  47. smriti_vendor/smriti/__pycache__/semantic_memory.cpython-310.pyc +0 -0
  48. smriti_vendor/smriti/agent.py +262 -0
  49. smriti_vendor/smriti/api.py +538 -0
  50. smriti_vendor/smriti/backends.py +494 -0
README.md ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ library_name: smriti-ai
6
+ tags:
7
+ - ai-agent
8
+ - memory
9
+ - small-language-models
10
+ - inference-time-augmentation
11
+ - semantic-search
12
+ - knowledge-graph
13
+ - identity-continuity
14
+ - rag
15
+ pipeline_tag: text-generation
16
+ ---
17
+
18
+ # Smriti AI
19
+
20
+ ## What this is
21
+
22
+ Smriti AI is a memory-augmented inference layer for small language models. It adds external memory, semantic retrieval, knowledge-graph recall, identity continuity, and privacy-ready memory deletion without changing base model weights.
23
+
24
+ This repository layout is intended for a Hugging Face model-style deployment with a custom `handler.py`. The handler loads a base causal language model or calls a remote model endpoint, wraps it with Smriti AI memory, and returns model responses plus retrieved memories.
25
+
26
+ ## What this is not
27
+
28
+ Smriti AI is not a newly trained foundation model. It is not a fine-tuned model unless a separate fine-tuned checkpoint is explicitly included. It is an inference-time wrapper around a base language model.
29
+
30
+ Do not interpret this repository as a standalone model checkpoint. The base model is configured through `BASE_MODEL_ID` or `HF_ENDPOINT_URL`.
31
+
32
+ ## Research Lineage
33
+
34
+ Smriti AI follows four principles:
35
+
36
+ - **External memory**: conversational facts live outside model weights in a persistent, inspectable store.
37
+ - **Training-free recall**: relevant facts are retrieved and injected at inference time without fine-tuning the base model.
38
+ - **Identity continuity**: persona evidence is tracked as an embedding fingerprint so outputs can be checked for drift.
39
+ - **Small-model augmentation**: small causal language models can become more useful when paired with explicit memory and retrieval.
40
+
41
+ Historical GodelAI-Lite results were measured on an earlier system. Current Smriti AI results are measured separately and should not be conflated with historical results.
42
+
43
+ ## Architecture
44
+
45
+ ```text
46
+ User request
47
+ -> Smriti AI handler
48
+ -> memory retrieval
49
+ -> graph retrieval
50
+ -> identity context
51
+ -> base model inference
52
+ -> response
53
+ -> memory write/update
54
+ ```
55
+
56
+ The handler supports JSON, SQLite, Redis, and Postgres memory backends. For production, use Redis/Postgres or another external durable store. Do not store private user memory in the Hugging Face model repository.
57
+
58
+ ## Supported base models
59
+
60
+ Smriti AI is model-agnostic for Hugging Face causal language models.
61
+
62
+ Supported families depend on the installed `transformers` version and endpoint hardware:
63
+
64
+ - Gemma-style causal LMs when available, including the current benchmark path `google/gemma-4-E2B-it`.
65
+ - Llama/Phi/Mistral/Qwen-style causal LMs if supported by the runtime environment.
66
+ - Tiny CPU-safe local smoke-test models such as `sshleifer/tiny-gpt2` for handler validation only.
67
+
68
+ Tiny models are useful for endpoint plumbing tests. They are not public Smriti AI quality benchmarks.
69
+
70
+ ## Evaluation
71
+
72
+ Current local Gemma 4-only benchmark artifacts in the main Smriti AI repository report:
73
+
74
+ | Evaluation | Baseline Recall | Smriti AI Recall | Notes |
75
+ |---|---:|---:|---|
76
+ | Gemma-style three-fact protocol | 0/3 | 3/3 | Smriti AI recalls all injected facts after distractors. |
77
+ | Five-mode comparison | 0/3 | 3/3 | TF-IDF, Semantic, Semantic+Graph, and Semantic+Graph+Identity all recall 3/3 in the checked-in run. |
78
+ | Original broader protocol rerun | 0/3 | 3/3 | Overall average improves from 0.524 to 0.832 (`+58.9%`) in the current local Gemma 4 CPU rerun. |
79
+
80
+ Historical GodelAI-Lite results were measured on an earlier system. Current Smriti AI results are measured separately and should not be conflated with historical results.
81
+
82
+ ## Privacy
83
+
84
+ Smriti AI stores user memory. Treat it as user data.
85
+
86
+ - Memory can be encrypted by setting `SMRITI_ENCRYPTION_KEY`.
87
+ - `delete_memory` is supported by the handler.
88
+ - Production deployments should use external memory storage such as Redis/Postgres.
89
+ - Do not store private user memory in the Hugging Face model repository.
90
+ - Public/demo deployments should not receive real PII.
91
+
92
+ ## Limitations
93
+
94
+ - Retrieval quality depends on the quality and specificity of stored memory.
95
+ - Public/demo deployments should not receive real PII.
96
+ - Durable memory requires external backend or persistent endpoint storage.
97
+ - Latency depends on the base model, backend, retrieval mode, and endpoint hardware.
98
+ - A tiny CPU demo model validates handler plumbing but will not produce Gemma-quality answers.
99
+ - If no `BASE_MODEL_ID` or `HF_ENDPOINT_URL` is configured, the handler falls back to memory-only responses.
100
+
101
+ ## Environment variables
102
+
103
+ | Variable | Purpose |
104
+ |---|---|
105
+ | `BASE_MODEL_ID` | Hugging Face model ID to load inside the endpoint. |
106
+ | `HF_ENDPOINT_URL` | Optional remote model endpoint URL. If set, the handler calls this URL instead of loading a local base model. |
107
+ | `HF_TOKEN` | Token for gated/private base models or protected remote endpoints. |
108
+ | `SMRITI_MEMORY_BACKEND` | `json`, `sqlite`, `redis`, or `postgres`. |
109
+ | `SMRITI_MEMORY_PATH` | JSON user-memory directory or SQLite file path. |
110
+ | `REDIS_URL` | External Redis URL. Takes precedence when present. |
111
+ | `POSTGRES_DSN` | External Postgres DSN. Takes precedence when present and Redis is not configured. |
112
+ | `SMRITI_ENCRYPTION_KEY` | Memory encryption key. Do not commit it. |
113
+ | `SMRITI_RETRIEVAL_MODE` | `tfidf`, `semantic`, `semantic_graph`, or `semantic_graph_identity`. |
114
+ | `SMRITI_PUBLIC_DEMO` | `true` or `false`. Use `true` only for non-PII demos. |
115
+ | `SMRITI_MAX_MEMORY_ENTRIES` | Maximum retained entries per user/topic. |
116
+
117
+ ## How to call the endpoint
118
+
119
+ ### Chat / fact injection
120
+
121
+ ```json
122
+ {
123
+ "inputs": {
124
+ "operation": "chat",
125
+ "user_id": "customer-123",
126
+ "message": "My name is Alex and I am a marine biologist.",
127
+ "retrieval_mode": "semantic_graph_identity"
128
+ },
129
+ "parameters": {
130
+ "max_new_tokens": 256,
131
+ "temperature": 0.7,
132
+ "top_p": 0.9,
133
+ "return_memories": true
134
+ }
135
+ }
136
+ ```
137
+
138
+ ### Recall
139
+
140
+ ```json
141
+ {
142
+ "inputs": {
143
+ "operation": "chat",
144
+ "user_id": "customer-123",
145
+ "message": "What do you remember about me?",
146
+ "retrieval_mode": "semantic_graph_identity"
147
+ },
148
+ "parameters": {
149
+ "return_memories": true
150
+ }
151
+ }
152
+ ```
153
+
154
+ ### Delete memory
155
+
156
+ ```json
157
+ {
158
+ "inputs": {
159
+ "operation": "delete_memory",
160
+ "user_id": "customer-123"
161
+ }
162
+ }
163
+ ```
164
+
165
+ ### Health
166
+
167
+ ```json
168
+ {
169
+ "inputs": {
170
+ "operation": "health"
171
+ }
172
+ }
173
+ ```
174
+
175
+ ## Local test
176
+
177
+ ```bash
178
+ pip install -r requirements.txt
179
+ BASE_MODEL_ID=sshleifer/tiny-gpt2 \
180
+ SMRITI_MEMORY_BACKEND=json \
181
+ SMRITI_MEMORY_PATH=/tmp/smriti_hf_test.json \
182
+ python test_handler_local.py
183
+ ```
184
+
185
+ ## Custom-container deployment
186
+
187
+ If the standard Hugging Face handler is insufficient for your model size, CUDA libraries, Redis client policy, or enterprise network requirements, deploy the same files in a custom container. Use the main Smriti AI repository Dockerfiles as the starting point, install this handler, and expose a compatible HTTP API through Hugging Face Inference Endpoints custom container support.
config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "project": "Smriti AI",
3
+ "base_model": "REPLACE_WITH_BASE_MODEL_ID",
4
+ "retrieval_mode": "semantic_graph_identity",
5
+ "memory_backend": "json",
6
+ "public_demo": false,
7
+ "max_memory_entries": 1000,
8
+ "enable_identity": true,
9
+ "enable_graph": true,
10
+ "enable_encryption": true
11
+ }
examples/request_delete.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "inputs": {
3
+ "operation": "delete_memory",
4
+ "user_id": "demo-user"
5
+ }
6
+ }
examples/request_distractor.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "inputs": {
3
+ "operation": "chat",
4
+ "user_id": "demo-user",
5
+ "message": "What is the capital of France?",
6
+ "retrieval_mode": "semantic_graph_identity"
7
+ }
8
+ }
examples/request_memory_inject.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "inputs": {
3
+ "operation": "chat",
4
+ "user_id": "demo-user",
5
+ "message": "My name is Alex and I am a marine biologist based in Hawaii.",
6
+ "retrieval_mode": "semantic_graph_identity"
7
+ },
8
+ "parameters": {
9
+ "return_memories": true
10
+ }
11
+ }
examples/request_recall.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "inputs": {
3
+ "operation": "chat",
4
+ "user_id": "demo-user",
5
+ "message": "What do you remember about me?",
6
+ "retrieval_mode": "semantic_graph_identity"
7
+ },
8
+ "parameters": {
9
+ "return_memories": true
10
+ }
11
+ }
handler.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face custom inference handler for Smriti AI.
2
+
3
+ This file is intentionally deployment glue. Core memory, retrieval, graph, and
4
+ identity behavior comes from the installed `smriti` package.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import logging
11
+ import os
12
+ import re
13
+ import sys
14
+ import time
15
+ import urllib.error
16
+ import urllib.request
17
+ from pathlib import Path
18
+ from threading import RLock
19
+ from typing import Any, Dict, List, Optional, Tuple
20
+
21
+ VENDOR_SRC = Path(__file__).resolve().parent / "smriti_vendor"
22
+ if VENDOR_SRC.exists() and str(VENDOR_SRC) not in sys.path:
23
+ sys.path.insert(0, str(VENDOR_SRC))
24
+
25
+ from smriti import IdentityFingerprint, MemPalaceLite, SmritiAILite # noqa: E402
26
+ from smriti.backends import ( # noqa: E402
27
+ JsonBackend,
28
+ MemoryBackend,
29
+ MemoryCipher,
30
+ PostgresBackend,
31
+ RedisBackend,
32
+ SqliteBackend,
33
+ )
34
+
35
+ LOGGER = logging.getLogger("smriti.hf_handler")
36
+ if not LOGGER.handlers:
37
+ logging.basicConfig(level=os.getenv("SMRITI_LOG_LEVEL", "INFO"))
38
+
39
+ DEFAULT_CONFIG = {
40
+ "project": "Smriti AI",
41
+ "base_model": "REPLACE_WITH_BASE_MODEL_ID",
42
+ "retrieval_mode": "semantic_graph_identity",
43
+ "memory_backend": "json",
44
+ "public_demo": False,
45
+ "max_memory_entries": 1000,
46
+ "enable_identity": True,
47
+ "enable_graph": True,
48
+ "enable_encryption": True,
49
+ }
50
+
51
+
52
+ class EndpointHandler:
53
+ """Hugging Face custom inference endpoint handler."""
54
+
55
+ def __init__(self, path: str = ""):
56
+ self.root = _resolve_root(path)
57
+ self.config = _load_config(self.root / "config.json")
58
+ self.lock = RLock()
59
+ self.memories: Dict[str, MemPalaceLite] = {}
60
+ self.identities: Dict[str, IdentityFingerprint] = {}
61
+ self.backend_warning: Optional[str] = None
62
+
63
+ self.base_model_id = _clean_model_id(
64
+ os.getenv("BASE_MODEL_ID") or self.config.get("base_model", "")
65
+ )
66
+ self.endpoint_url = os.getenv("HF_ENDPOINT_URL", "").strip()
67
+ self.hf_token = os.getenv("HF_TOKEN", "").strip()
68
+ self.default_retrieval_mode = os.getenv(
69
+ "SMRITI_RETRIEVAL_MODE",
70
+ str(self.config.get("retrieval_mode", "semantic_graph_identity")),
71
+ )
72
+ self.max_memory_entries = _int_env(
73
+ "SMRITI_MAX_MEMORY_ENTRIES",
74
+ int(self.config.get("max_memory_entries", 1000)),
75
+ )
76
+ self.public_demo = _bool_env("SMRITI_PUBLIC_DEMO", bool(self.config.get("public_demo", False)))
77
+ self.enable_graph_default = bool(self.config.get("enable_graph", True))
78
+ self.enable_identity_default = bool(self.config.get("enable_identity", True))
79
+ self.enable_encryption = bool(self.config.get("enable_encryption", True))
80
+
81
+ self.backend, self.backend_name = self._init_backend()
82
+ self.model = None
83
+ self.tokenizer = None
84
+ self.device = "cpu"
85
+ if self.endpoint_url:
86
+ LOGGER.info(
87
+ "Smriti AI handler using remote model endpoint; backend=%s retrieval=%s",
88
+ self.backend_name,
89
+ self.default_retrieval_mode,
90
+ )
91
+ elif self.base_model_id:
92
+ self._load_local_model(self.base_model_id)
93
+ else:
94
+ LOGGER.warning(
95
+ "No BASE_MODEL_ID or HF_ENDPOINT_URL configured; handler will run memory-only."
96
+ )
97
+
98
+ LOGGER.info(
99
+ "Smriti AI handler ready: base_model=%s remote_endpoint=%s backend=%s retrieval=%s encryption=%s public_demo=%s",
100
+ self.base_model_id or "memory-only",
101
+ bool(self.endpoint_url),
102
+ self.backend_name,
103
+ self.default_retrieval_mode,
104
+ self.enable_encryption and bool(os.getenv("SMRITI_ENCRYPTION_KEY") or os.getenv("SMRITI_MEMORY_KEY")),
105
+ self.public_demo,
106
+ )
107
+
108
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
109
+ start = time.perf_counter()
110
+ try:
111
+ inputs, parameters = _normalize_request(data)
112
+ operation = str(inputs.get("operation", "chat")).lower()
113
+
114
+ if operation == "health":
115
+ return self._health(start)
116
+ if operation == "delete_memory":
117
+ return self._delete_memory(inputs, start)
118
+ if operation != "chat":
119
+ return _error(f"Unsupported operation: {operation}", start)
120
+
121
+ return self._chat(inputs, parameters, start)
122
+ except Exception as exc: # Defensive boundary for endpoint runtimes.
123
+ LOGGER.exception("Unhandled Smriti AI handler error")
124
+ return _error(f"handler_error:{exc.__class__.__name__}: {exc}", start)
125
+
126
+ # ------------------------------------------------------------------
127
+ # Operation handlers
128
+ # ------------------------------------------------------------------
129
+
130
+ def _chat(
131
+ self,
132
+ inputs: Dict[str, Any],
133
+ parameters: Dict[str, Any],
134
+ start: float,
135
+ ) -> Dict[str, Any]:
136
+ user_id = str(inputs.get("user_id") or "").strip()
137
+ message = str(inputs.get("message") or "").strip()
138
+ topic_id = str(inputs.get("topic_id") or "general").strip() or "general"
139
+ if not user_id:
140
+ return _error("user_id is required", start)
141
+ if not message:
142
+ return _error("message is required for chat operation", start)
143
+
144
+ retrieval_mode = str(inputs.get("retrieval_mode") or self.default_retrieval_mode)
145
+ base_retrieval = _base_retrieval_mode(retrieval_mode)
146
+ include_graph = self.enable_graph_default and "graph" in retrieval_mode
147
+ identity_enabled = self.enable_identity_default and "identity" in retrieval_mode
148
+
149
+ with self.lock:
150
+ memory = self._get_memory(user_id, topic_id, base_retrieval)
151
+ context, retrieved_memories, graph_facts, retrieval_warning = self._retrieve_context(
152
+ memory,
153
+ user_id,
154
+ topic_id,
155
+ message,
156
+ include_graph,
157
+ )
158
+ identity = self._get_identity(user_id, identity_enabled)
159
+ agent = SmritiAILite(
160
+ model=self.model,
161
+ tokenizer=self.tokenizer,
162
+ retrieval_mode=base_retrieval,
163
+ session_id=user_id,
164
+ topic_id=topic_id,
165
+ memory=memory,
166
+ identity=identity,
167
+ auto_device=False,
168
+ )
169
+ agent.build_prompt = lambda user_input: _build_prompt(
170
+ agent,
171
+ memory,
172
+ user_id,
173
+ topic_id,
174
+ user_input,
175
+ include_graph,
176
+ identity_enabled,
177
+ )
178
+
179
+ generation_calls = 0
180
+
181
+ def generate(prompt: str, max_tokens: int = 256) -> str:
182
+ nonlocal generation_calls
183
+ generation_calls += 1
184
+ return self._generate_text(prompt, parameters, max_tokens=max_tokens)
185
+
186
+ agent._generate = generate # type: ignore[method-assign]
187
+ try:
188
+ response = agent.chat(message)
189
+ except Exception as exc:
190
+ LOGGER.exception("Model generation failed")
191
+ return _error(f"model_generation_failed:{exc.__class__.__name__}: {exc}", start)
192
+
193
+ response = _stabilize_recall_answer(message, response, retrieved_memories, graph_facts)
194
+ _replace_last_assistant_history(memory, response)
195
+
196
+ identity_check = agent.identity.evaluate_output(response) if identity_enabled else None
197
+ save_warning = self._save_memory(user_id, memory)
198
+
199
+ warnings = [item for item in [self.backend_warning, retrieval_warning, save_warning] if item]
200
+ return {
201
+ "response": response,
202
+ "retrieved_memories": retrieved_memories,
203
+ "graph_facts": graph_facts,
204
+ "identity": {
205
+ "enabled": identity_enabled,
206
+ "drift_score": float(identity_check.distance) if identity_check else 0.0,
207
+ "refinement_triggered": generation_calls > 1,
208
+ },
209
+ "latency_ms": round((time.perf_counter() - start) * 1000, 3),
210
+ "backend": self.backend_name,
211
+ "retrieval_mode": retrieval_mode,
212
+ "warnings": warnings,
213
+ }
214
+
215
+ def _delete_memory(self, inputs: Dict[str, Any], start: float) -> Dict[str, Any]:
216
+ user_id = str(inputs.get("user_id") or "").strip()
217
+ if not user_id:
218
+ return _error("user_id is required for delete_memory operation", start)
219
+ with self.lock:
220
+ existed_cache = self.memories.pop(user_id, None) is not None
221
+ self.identities.pop(user_id, None)
222
+ try:
223
+ deleted_backend = self.backend.delete_user(user_id)
224
+ except Exception as exc:
225
+ LOGGER.exception("Memory backend delete failed")
226
+ return _error(f"backend_delete_failed:{exc.__class__.__name__}: {exc}", start)
227
+ return {
228
+ "deleted": bool(existed_cache or deleted_backend),
229
+ "user_id": user_id,
230
+ "latency_ms": round((time.perf_counter() - start) * 1000, 3),
231
+ "backend": self.backend_name,
232
+ }
233
+
234
+ def _health(self, start: float) -> Dict[str, Any]:
235
+ return {
236
+ "status": "ok",
237
+ "project": "Smriti AI",
238
+ "base_model": self.base_model_id or ("remote-endpoint" if self.endpoint_url else "memory-only"),
239
+ "backend": self.backend_name,
240
+ "retrieval_mode": self.default_retrieval_mode,
241
+ "latency_ms": round((time.perf_counter() - start) * 1000, 3),
242
+ }
243
+
244
+ # ------------------------------------------------------------------
245
+ # Runtime setup
246
+ # ------------------------------------------------------------------
247
+
248
+ def _init_backend(self) -> Tuple[MemoryBackend, str]:
249
+ encryption_key = os.getenv("SMRITI_ENCRYPTION_KEY") or os.getenv("SMRITI_MEMORY_KEY")
250
+ if encryption_key:
251
+ os.environ["SMRITI_MEMORY_KEY"] = encryption_key
252
+ cipher = MemoryCipher(encryption_key if self.enable_encryption else None)
253
+
254
+ redis_url = os.getenv("REDIS_URL") or os.getenv("SMRITI_REDIS_URL")
255
+ postgres_dsn = os.getenv("POSTGRES_DSN") or os.getenv("SMRITI_POSTGRES_DSN")
256
+ selected = (os.getenv("SMRITI_MEMORY_BACKEND") or self.config.get("memory_backend") or "json").lower()
257
+ memory_path = os.getenv("SMRITI_MEMORY_PATH", "/tmp/smriti_hf_memory")
258
+
259
+ if redis_url:
260
+ return RedisBackend(url=redis_url, cipher=cipher), "redis"
261
+ if postgres_dsn:
262
+ return PostgresBackend(dsn=postgres_dsn, cipher=cipher), "postgres"
263
+ if selected == "redis":
264
+ return RedisBackend(url=redis_url or "redis://localhost:6379/0", cipher=cipher), "redis"
265
+ if selected in {"postgres", "postgresql"}:
266
+ return PostgresBackend(dsn=postgres_dsn or "", cipher=cipher), "postgres"
267
+ if selected == "sqlite":
268
+ return SqliteBackend(path=memory_path, cipher=cipher), "sqlite"
269
+ return JsonBackend(root=_json_root(memory_path), cipher=cipher), "json"
270
+
271
+ def _load_local_model(self, model_id: str) -> None:
272
+ try:
273
+ import torch
274
+ from transformers import AutoModelForCausalLM, AutoTokenizer
275
+ except Exception as exc:
276
+ raise RuntimeError("Install torch and transformers to load a local base model.") from exc
277
+
278
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
279
+ dtype = torch.float32
280
+ if self.device == "cuda":
281
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
282
+
283
+ kwargs = {"token": self.hf_token} if self.hf_token else {}
284
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
285
+ if getattr(self.tokenizer, "pad_token_id", None) is None:
286
+ self.tokenizer.pad_token = self.tokenizer.eos_token
287
+ try:
288
+ self.model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, **kwargs)
289
+ except TypeError:
290
+ self.model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **kwargs)
291
+ self.model.to(self.device)
292
+ self.model.eval()
293
+ LOGGER.info("Loaded local base model %s on %s", model_id, self.device)
294
+
295
+ # ------------------------------------------------------------------
296
+ # Memory and generation helpers
297
+ # ------------------------------------------------------------------
298
+
299
+ def _get_memory(self, user_id: str, topic_id: str, retrieval_mode: str) -> MemPalaceLite:
300
+ self.backend_warning = None
301
+ if user_id not in self.memories:
302
+ state = None
303
+ try:
304
+ state = self.backend.load(user_id)
305
+ except Exception as exc:
306
+ LOGGER.exception("Memory backend load failed; starting empty memory")
307
+ self.backend_warning = f"backend_load_failed:{exc.__class__.__name__}"
308
+ if state:
309
+ memory = MemPalaceLite.from_dict(
310
+ state,
311
+ retrieval_mode=retrieval_mode,
312
+ session_id=user_id,
313
+ topic_id=topic_id,
314
+ max_facts=self.max_memory_entries,
315
+ max_entries_per_topic=self.max_memory_entries,
316
+ )
317
+ else:
318
+ memory = MemPalaceLite(
319
+ retrieval_mode=retrieval_mode,
320
+ session_id=user_id,
321
+ topic_id=topic_id,
322
+ max_facts=self.max_memory_entries,
323
+ max_entries_per_topic=self.max_memory_entries,
324
+ )
325
+ self.memories[user_id] = memory
326
+ memory = self.memories[user_id]
327
+ if memory.retrieval_mode != retrieval_mode:
328
+ memory = MemPalaceLite.from_dict(
329
+ memory.to_dict(),
330
+ retrieval_mode=retrieval_mode,
331
+ session_id=user_id,
332
+ topic_id=topic_id,
333
+ max_facts=self.max_memory_entries,
334
+ max_entries_per_topic=self.max_memory_entries,
335
+ )
336
+ self.memories[user_id] = memory
337
+ memory.session_id = user_id
338
+ memory.topic_id = topic_id
339
+ return memory
340
+
341
+ def _get_identity(self, user_id: str, enabled: bool) -> IdentityFingerprint:
342
+ if user_id not in self.identities:
343
+ threshold = 0.35 if enabled else 2.0
344
+ self.identities[user_id] = IdentityFingerprint(
345
+ role="helpful AI assistant with persistent memory",
346
+ threshold=threshold,
347
+ )
348
+ identity = self.identities[user_id]
349
+ if not enabled:
350
+ identity.threshold = 2.0
351
+ return identity
352
+
353
+ def _retrieve_context(
354
+ self,
355
+ memory: MemPalaceLite,
356
+ user_id: str,
357
+ topic_id: str,
358
+ message: str,
359
+ include_graph: bool,
360
+ ) -> Tuple[str, List[str], List[str], Optional[str]]:
361
+ try:
362
+ context = memory.get_context(
363
+ query=message,
364
+ session_id=user_id,
365
+ topic_id=topic_id,
366
+ include_graph=include_graph,
367
+ )
368
+ retrieved_memories = memory.retrieve_facts(
369
+ message,
370
+ k=5,
371
+ session_id=user_id,
372
+ topic_id=topic_id,
373
+ )
374
+ graph_facts = _section_bullets(context, "[RELATED GRAPH FACTS]") if include_graph else []
375
+ return context, retrieved_memories, graph_facts, None
376
+ except Exception as exc:
377
+ LOGGER.exception("Memory retrieval failed")
378
+ return "", [], [], f"retrieval_failed:{exc.__class__.__name__}"
379
+
380
+ def _save_memory(self, user_id: str, memory: MemPalaceLite) -> Optional[str]:
381
+ try:
382
+ self.backend.save(user_id, memory.to_dict())
383
+ return None
384
+ except Exception as exc:
385
+ LOGGER.exception("Memory backend save failed")
386
+ return f"backend_save_failed:{exc.__class__.__name__}"
387
+
388
+ def _generate_text(self, prompt: str, parameters: Dict[str, Any], max_tokens: int = 256) -> str:
389
+ max_new_tokens = int(parameters.get("max_new_tokens", max_tokens) or max_tokens)
390
+ temperature = float(parameters.get("temperature", 0.7))
391
+ top_p = float(parameters.get("top_p", 0.9))
392
+ if self.endpoint_url:
393
+ return self._generate_remote(prompt, max_new_tokens, temperature, top_p)
394
+ if self.model is not None and self.tokenizer is not None:
395
+ return self._generate_local(prompt, max_new_tokens, temperature, top_p)
396
+ return _memory_only_answer(prompt)
397
+
398
+ def _generate_local(
399
+ self,
400
+ prompt: str,
401
+ max_new_tokens: int,
402
+ temperature: float,
403
+ top_p: float,
404
+ ) -> str:
405
+ import torch
406
+
407
+ messages = [{"role": "user", "content": prompt}]
408
+ try:
409
+ formatted = self.tokenizer.apply_chat_template(
410
+ messages,
411
+ tokenize=False,
412
+ add_generation_prompt=True,
413
+ )
414
+ except Exception:
415
+ formatted = prompt
416
+ inputs = self.tokenizer(
417
+ formatted,
418
+ return_tensors="pt",
419
+ truncation=True,
420
+ max_length=2048,
421
+ )
422
+ inputs = {key: value.to(self.device) for key, value in inputs.items()}
423
+ generate_kwargs = {
424
+ "max_new_tokens": max_new_tokens,
425
+ "do_sample": temperature > 0,
426
+ "pad_token_id": getattr(self.tokenizer, "eos_token_id", None),
427
+ }
428
+ if temperature > 0:
429
+ generate_kwargs["temperature"] = temperature
430
+ generate_kwargs["top_p"] = top_p
431
+ with torch.inference_mode():
432
+ output = self.model.generate(**inputs, **generate_kwargs)
433
+ return self.tokenizer.decode(
434
+ output[0, inputs["input_ids"].shape[1] :].detach().cpu(),
435
+ skip_special_tokens=True,
436
+ ).strip()
437
+
438
+ def _generate_remote(
439
+ self,
440
+ prompt: str,
441
+ max_new_tokens: int,
442
+ temperature: float,
443
+ top_p: float,
444
+ ) -> str:
445
+ payload = {
446
+ "inputs": prompt,
447
+ "parameters": {
448
+ "max_new_tokens": max_new_tokens,
449
+ "temperature": temperature,
450
+ "top_p": top_p,
451
+ },
452
+ }
453
+ headers = {"Content-Type": "application/json"}
454
+ if self.hf_token:
455
+ headers["Authorization"] = f"Bearer {self.hf_token}"
456
+ request = urllib.request.Request(
457
+ self.endpoint_url,
458
+ data=json.dumps(payload).encode("utf-8"),
459
+ headers=headers,
460
+ method="POST",
461
+ )
462
+ try:
463
+ with urllib.request.urlopen(request, timeout=120) as response: # noqa: S310
464
+ raw = response.read().decode("utf-8")
465
+ except urllib.error.HTTPError as exc:
466
+ body = exc.read().decode("utf-8", errors="replace")
467
+ raise RuntimeError(f"remote endpoint HTTP {exc.code}: {body[:300]}") from exc
468
+ parsed = json.loads(raw)
469
+ return _extract_generated_text(parsed)
470
+
471
+
472
+ # ----------------------------------------------------------------------
473
+ # Request, context, and formatting helpers
474
+ # ----------------------------------------------------------------------
475
+
476
+
477
+ def _resolve_root(path: str) -> Path:
478
+ if path:
479
+ root = Path(path).resolve()
480
+ return root.parent if root.is_file() else root
481
+ return Path(__file__).resolve().parent
482
+
483
+
484
+ def _load_config(path: Path) -> Dict[str, Any]:
485
+ if not path.exists():
486
+ return dict(DEFAULT_CONFIG)
487
+ data = json.loads(path.read_text(encoding="utf-8"))
488
+ config = dict(DEFAULT_CONFIG)
489
+ config.update(data)
490
+ return config
491
+
492
+
493
+ def _normalize_request(data: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
494
+ if not isinstance(data, dict):
495
+ raise ValueError("Request body must be a JSON object.")
496
+ if "inputs" in data:
497
+ inputs = data.get("inputs") or {}
498
+ if isinstance(inputs, str):
499
+ inputs = {"message": inputs}
500
+ parameters = data.get("parameters") or {}
501
+ else:
502
+ inputs = data
503
+ parameters = data.get("parameters") or {}
504
+ if not isinstance(inputs, dict) or not isinstance(parameters, dict):
505
+ raise ValueError("inputs and parameters must be JSON objects.")
506
+ return inputs, parameters
507
+
508
+
509
+ def _base_retrieval_mode(mode: str) -> str:
510
+ return "tfidf" if str(mode).lower().startswith("tfidf") else "semantic"
511
+
512
+
513
+ def _build_prompt(
514
+ agent: SmritiAILite,
515
+ memory: MemPalaceLite,
516
+ user_id: str,
517
+ topic_id: str,
518
+ user_input: str,
519
+ include_graph: bool,
520
+ identity_enabled: bool,
521
+ ) -> str:
522
+ identity = agent.identity.get_identity_prompt() if identity_enabled else ""
523
+ context = memory.get_context(
524
+ query=user_input,
525
+ session_id=user_id,
526
+ topic_id=topic_id,
527
+ include_graph=include_graph,
528
+ )
529
+ parts = [part for part in [identity.strip(), context.strip(), user_input.strip()] if part]
530
+ return "\n\n".join(parts)
531
+
532
+
533
+ def _section_bullets(context: str, heading: str) -> List[str]:
534
+ if heading not in context:
535
+ return []
536
+ after = context.split(heading, 1)[1]
537
+ chunks = re.split(r"\n\[[A-Z ]+\]", after, maxsplit=1)
538
+ section = chunks[0]
539
+ bullets = []
540
+ for line in section.splitlines():
541
+ cleaned = line.strip()
542
+ if cleaned.startswith("*"):
543
+ bullets.append(cleaned.lstrip("* ").strip())
544
+ return bullets
545
+
546
+
547
+ def _memory_only_answer(prompt: str) -> str:
548
+ facts = _section_bullets(prompt, "[REMEMBERED FACTS]")
549
+ graph = _section_bullets(prompt, "[RELATED GRAPH FACTS]")
550
+ combined = facts + [item for item in graph if item not in facts]
551
+ if combined:
552
+ return "I remember: " + "; ".join(combined[:5])
553
+ return "Memory updated. No prior relevant context was found."
554
+
555
+
556
+ def _is_recall_query(message: str) -> bool:
557
+ lowered = message.lower()
558
+ return any(
559
+ phrase in lowered
560
+ for phrase in [
561
+ "remember",
562
+ "what do you know about me",
563
+ "who am i",
564
+ "where do i work",
565
+ "what is my name",
566
+ "what do i do",
567
+ ]
568
+ )
569
+
570
+
571
+ def _stabilize_recall_answer(
572
+ message: str,
573
+ response: str,
574
+ retrieved_memories: List[str],
575
+ graph_facts: List[str],
576
+ ) -> str:
577
+ if not _is_recall_query(message):
578
+ return response
579
+ combined = retrieved_memories + [item for item in graph_facts if item not in retrieved_memories]
580
+ if not combined:
581
+ return response
582
+ if _mentions_memory_terms(response, combined):
583
+ return response
584
+ return "I remember: " + "; ".join(combined[:5])
585
+
586
+
587
+ def _mentions_memory_terms(response: str, memories: List[str]) -> bool:
588
+ response_terms = set(re.findall(r"[a-z0-9']{4,}", response.lower()))
589
+ memory_terms = set()
590
+ for memory in memories:
591
+ memory_terms.update(re.findall(r"[a-z0-9']{4,}", memory.lower()))
592
+ return bool(response_terms & memory_terms)
593
+
594
+
595
+ def _replace_last_assistant_history(memory: MemPalaceLite, response: str) -> None:
596
+ if memory.history and memory.history[-1].category == "assistant_output":
597
+ memory.history[-1].content = "Assistant: " + response[:200]
598
+
599
+
600
+ def _extract_generated_text(parsed: Any) -> str:
601
+ if isinstance(parsed, list) and parsed:
602
+ return _extract_generated_text(parsed[0])
603
+ if isinstance(parsed, dict):
604
+ for key in ["generated_text", "response", "text", "output"]:
605
+ value = parsed.get(key)
606
+ if isinstance(value, str):
607
+ return value.strip()
608
+ if "outputs" in parsed:
609
+ return _extract_generated_text(parsed["outputs"])
610
+ if isinstance(parsed, str):
611
+ return parsed.strip()
612
+ raise RuntimeError("Remote endpoint did not return generated text.")
613
+
614
+
615
+ def _json_root(memory_path: str) -> Path:
616
+ path = Path(memory_path)
617
+ if path.suffix.lower() in {".json", ".jsonl"}:
618
+ return path.with_suffix("")
619
+ return path
620
+
621
+
622
+ def _clean_model_id(value: str) -> str:
623
+ value = (value or "").strip()
624
+ if not value or value == "REPLACE_WITH_BASE_MODEL_ID":
625
+ return ""
626
+ return value
627
+
628
+
629
+ def _bool_env(name: str, default: bool) -> bool:
630
+ raw = os.getenv(name)
631
+ if raw is None:
632
+ return default
633
+ return raw.strip().lower() in {"1", "true", "yes", "on"}
634
+
635
+
636
+ def _int_env(name: str, default: int) -> int:
637
+ try:
638
+ return int(os.getenv(name, str(default)))
639
+ except ValueError:
640
+ return default
641
+
642
+
643
+ def _error(message: str, start: float) -> Dict[str, Any]:
644
+ return {
645
+ "error": message,
646
+ "latency_ms": round((time.perf_counter() - start) * 1000, 3),
647
+ }
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Smriti AI is not yet assumed to be published on PyPI for this deployment artifact.
2
+ # Until it is published, install the package directly from the GitHub repository.
3
+ git+https://github.com/Luciferai04/smriti-ai.git
4
+
5
+ # After PyPI publication, replace the GitHub line above with:
6
+ # smriti-ai>=0.3.1
7
+
8
+ transformers
9
+ accelerate
10
+ torch
11
+ sentence-transformers
12
+ faiss-cpu
13
+ networkx
14
+ cryptography
15
+ pydantic
16
+ redis
17
+ psycopg2-binary
18
+ huggingface_hub
19
+ requests
smriti_endpoint_config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Smriti AI Hugging Face Inference Endpoint configuration template.
2
+ # Values here are documentation defaults. Set real values as endpoint environment
3
+ # variables or managed secrets, not as committed plaintext.
4
+
5
+ BASE_MODEL_ID: ""
6
+ HF_ENDPOINT_URL: ""
7
+ HF_TOKEN: ""
8
+ SMRITI_MEMORY_BACKEND: "json"
9
+ SMRITI_MEMORY_PATH: "/data/smriti_memory"
10
+ REDIS_URL: ""
11
+ POSTGRES_DSN: ""
12
+ SMRITI_ENCRYPTION_KEY: ""
13
+ SMRITI_RETRIEVAL_MODE: "semantic_graph_identity"
14
+ SMRITI_PUBLIC_DEMO: "false"
15
+ SMRITI_MAX_MEMORY_ENTRIES: "1000"
16
+
17
+ warnings:
18
+ - Do not commit HF_TOKEN.
19
+ - Do not commit SMRITI_ENCRYPTION_KEY.
20
+ - Production memory should use Redis/Postgres or another external durable storage service.
21
+ - The Hugging Face model repository should not contain user memory files.
22
+ - Public demo endpoints should not receive real PII.
23
+
24
+ variables:
25
+ BASE_MODEL_ID: Hugging Face model ID to load locally inside the endpoint.
26
+ HF_ENDPOINT_URL: Optional remote model endpoint URL. If set, Smriti calls it instead of loading BASE_MODEL_ID locally.
27
+ HF_TOKEN: Hugging Face token for gated/private base models or protected remote endpoints.
28
+ SMRITI_MEMORY_BACKEND: json | sqlite | redis | postgres.
29
+ SMRITI_MEMORY_PATH: Path for JSON user-memory directory or SQLite database file.
30
+ REDIS_URL: External Redis URL. Takes precedence when present.
31
+ POSTGRES_DSN: External Postgres DSN. Takes precedence when present and REDIS_URL is empty.
32
+ SMRITI_ENCRYPTION_KEY: Encryption key for user memory. Maps to Smriti's SMRITI_MEMORY_KEY.
33
+ SMRITI_RETRIEVAL_MODE: tfidf | semantic | semantic_graph | semantic_graph_identity.
34
+ SMRITI_PUBLIC_DEMO: true | false. Use true only for non-PII demos.
35
+ SMRITI_MAX_MEMORY_ENTRIES: Maximum fact entries retained per user/topic.
smriti_vendor/mempalace/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Backward-compatible imports for the renamed :mod:`smriti` package."""
2
+
3
+ from smriti import * # noqa: F401,F403
smriti_vendor/mempalace/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (224 Bytes). View file
 
smriti_vendor/mempalace/__pycache__/agent.cpython-310.pyc ADDED
Binary file (207 Bytes). View file
 
smriti_vendor/mempalace/__pycache__/api.cpython-310.pyc ADDED
Binary file (201 Bytes). View file
 
smriti_vendor/mempalace/__pycache__/cli.cpython-310.pyc ADDED
Binary file (201 Bytes). View file
 
smriti_vendor/mempalace/__pycache__/core.cpython-310.pyc ADDED
Binary file (204 Bytes). View file
 
smriti_vendor/mempalace/__pycache__/gifp.cpython-310.pyc ADDED
Binary file (204 Bytes). View file
 
smriti_vendor/mempalace/__pycache__/identity_fingerprint.cpython-310.pyc ADDED
Binary file (252 Bytes). View file
 
smriti_vendor/mempalace/__pycache__/knowledge_graph.cpython-310.pyc ADDED
Binary file (237 Bytes). View file
 
smriti_vendor/mempalace/__pycache__/macp.cpython-310.pyc ADDED
Binary file (204 Bytes). View file
 
smriti_vendor/mempalace/__pycache__/mem_palace.cpython-310.pyc ADDED
Binary file (222 Bytes). View file
 
smriti_vendor/mempalace/__pycache__/semantic_memory.cpython-310.pyc ADDED
Binary file (237 Bytes). View file
 
smriti_vendor/mempalace/agent.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Compatibility wrapper for :mod:`smriti.agent`."""
2
+
3
+ from smriti.agent import * # noqa: F401,F403
smriti_vendor/mempalace/api.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Compatibility wrapper for :mod:`smriti.api`."""
2
+
3
+ from smriti.api import * # noqa: F401,F403
smriti_vendor/mempalace/cli.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Compatibility wrapper for :mod:`smriti.cli`."""
2
+
3
+ from smriti.cli import * # noqa: F401,F403
smriti_vendor/mempalace/core.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Compatibility wrapper for :mod:`smriti.core`."""
2
+
3
+ from smriti.core import * # noqa: F401,F403
smriti_vendor/mempalace/gifp.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Compatibility wrapper for :mod:`smriti.gifp`."""
2
+
3
+ from smriti.gifp import * # noqa: F401,F403
smriti_vendor/mempalace/identity_fingerprint.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Compatibility wrapper for :mod:`smriti.identity_fingerprint`."""
2
+
3
+ from smriti.identity_fingerprint import * # noqa: F401,F403
smriti_vendor/mempalace/knowledge_graph.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Compatibility wrapper for :mod:`smriti.knowledge_graph`."""
2
+
3
+ from smriti.knowledge_graph import * # noqa: F401,F403
smriti_vendor/mempalace/macp.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Compatibility wrapper for :mod:`smriti.macp`."""
2
+
3
+ from smriti.macp import * # noqa: F401,F403
smriti_vendor/mempalace/mem_palace.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Compatibility wrapper for :mod:`smriti.mem_palace`."""
2
+
3
+ from smriti.mem_palace import * # noqa: F401,F403
smriti_vendor/mempalace/semantic_memory.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Compatibility wrapper for :mod:`smriti.semantic_memory`."""
2
+
3
+ from smriti.semantic_memory import * # noqa: F401,F403
smriti_vendor/smriti/__init__.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Smriti AI — Inference-time memory framework for small language models.
3
+
4
+ Smriti AI adds semantic memory, reasoning continuity, and identity
5
+ governance to any HuggingFace causal LM with zero fine-tuning. The name
6
+ comes from smriti, a Sanskrit term associated with memory and remembrance.
7
+
8
+ Features:
9
+ - Semantic memory with FAISS-based retrieval
10
+ - Knowledge graph integration
11
+ - Embedding-based identity governance (GIFP v1.0)
12
+ - Multi-user support via API and CLI
13
+
14
+ Quick start:
15
+ from smriti import MemPalaceLite, SmritiAILite
16
+
17
+ memory = MemPalaceLite(retrieval_mode="semantic")
18
+ agent = SmritiAILite(model=model, tokenizer=tokenizer)
19
+ reply = agent.chat("My name is Jordan and I am a marine biologist.")
20
+ """
21
+
22
+ from .agent import BaselineGemma, GodelAILite, SmritiAILite
23
+ from .backends import (
24
+ JsonBackend,
25
+ MemoryBackend,
26
+ MemoryCipher,
27
+ PostgresBackend,
28
+ RedisBackend,
29
+ SqliteBackend,
30
+ build_backend,
31
+ )
32
+ from .config import SmritiConfig, configure_environment_from_file, load_config, write_default_config
33
+ from .core import MemoryEntry, MemPalaceLite
34
+ from .gifp import GIFPLite
35
+ from .macp import MACPLite, ReasoningStep
36
+
37
+ # New modules for enhanced functionality
38
+ try:
39
+ from .semantic_memory import (
40
+ RetrievalResult,
41
+ SemanticMemory,
42
+ MemoryEntry as SemanticMemoryEntry,
43
+ )
44
+ except ImportError:
45
+ RetrievalResult = None
46
+ SemanticMemory = None
47
+ SemanticMemoryEntry = None
48
+
49
+ try:
50
+ from .knowledge_graph import GraphTriple, KnowledgeGraphMemory
51
+ except ImportError:
52
+ GraphTriple = None
53
+ KnowledgeGraphMemory = None
54
+
55
+ try:
56
+ from .identity_fingerprint import IdentityCheck, IdentityFingerprint
57
+ except ImportError:
58
+ IdentityCheck = None
59
+ IdentityFingerprint = None
60
+
61
+ __version__ = "0.3.1"
62
+ __author__ = "Alton Lee Wei Bin (creator35lwb)"
63
+
64
+ __all__ = [
65
+ "MemoryEntry",
66
+ "MemPalaceLite",
67
+ "ReasoningStep",
68
+ "MACPLite",
69
+ "GIFPLite",
70
+ "SmritiAILite",
71
+ "GodelAILite",
72
+ "BaselineGemma",
73
+ "MemoryBackend",
74
+ "MemoryCipher",
75
+ "JsonBackend",
76
+ "SqliteBackend",
77
+ "RedisBackend",
78
+ "PostgresBackend",
79
+ "build_backend",
80
+ "SmritiConfig",
81
+ "load_config",
82
+ "configure_environment_from_file",
83
+ "write_default_config",
84
+ ]
85
+
86
+ # Add new classes if available
87
+ if SemanticMemory is not None:
88
+ __all__.extend(["SemanticMemory", "SemanticMemoryEntry", "RetrievalResult"])
89
+ if KnowledgeGraphMemory is not None:
90
+ __all__.extend(["KnowledgeGraphMemory", "GraphTriple"])
91
+ if IdentityFingerprint is not None:
92
+ __all__.extend(["IdentityFingerprint", "IdentityCheck"])
93
+ __all__.extend(["api_app", "create_app", "get_memory", "set_agent_factory", "set_memory_backend", "cli_main"])
94
+
95
+
96
+ def __getattr__(name: str):
97
+ """Lazy optional API/CLI exports without double-registering Prometheus metrics."""
98
+
99
+ if name in {"api_app", "create_app", "get_memory", "set_agent_factory", "set_memory_backend"}:
100
+ from .api import app as api_app
101
+ from .api import create_app, get_memory, set_agent_factory, set_memory_backend
102
+
103
+ values = {
104
+ "api_app": api_app,
105
+ "create_app": create_app,
106
+ "get_memory": get_memory,
107
+ "set_agent_factory": set_agent_factory,
108
+ "set_memory_backend": set_memory_backend,
109
+ }
110
+ return values[name]
111
+ if name == "cli_main":
112
+ from .cli import main as cli_main
113
+
114
+ return cli_main
115
+ raise AttributeError(name)
smriti_vendor/smriti/__main__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Run the Smriti AI CLI with `python -m smriti`."""
2
+
3
+ from .cli import main
4
+
5
+
6
+ if __name__ == "__main__":
7
+ raise SystemExit(main())
smriti_vendor/smriti/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.83 kB). View file
 
smriti_vendor/smriti/__pycache__/__main__.cpython-310.pyc ADDED
Binary file (266 Bytes). View file
 
smriti_vendor/smriti/__pycache__/agent.cpython-310.pyc ADDED
Binary file (8.03 kB). View file
 
smriti_vendor/smriti/__pycache__/api.cpython-310.pyc ADDED
Binary file (16.1 kB). View file
 
smriti_vendor/smriti/__pycache__/backends.cpython-310.pyc ADDED
Binary file (20.5 kB). View file
 
smriti_vendor/smriti/__pycache__/cli.cpython-310.pyc ADDED
Binary file (8.95 kB). View file
 
smriti_vendor/smriti/__pycache__/config.cpython-310.pyc ADDED
Binary file (5.37 kB). View file
 
smriti_vendor/smriti/__pycache__/core.cpython-310.pyc ADDED
Binary file (13.9 kB). View file
 
smriti_vendor/smriti/__pycache__/gifp.cpython-310.pyc ADDED
Binary file (510 Bytes). View file
 
smriti_vendor/smriti/__pycache__/identity_fingerprint.cpython-310.pyc ADDED
Binary file (9.65 kB). View file
 
smriti_vendor/smriti/__pycache__/knowledge_graph.cpython-310.pyc ADDED
Binary file (13.3 kB). View file
 
smriti_vendor/smriti/__pycache__/macp.cpython-310.pyc ADDED
Binary file (2.13 kB). View file
 
smriti_vendor/smriti/__pycache__/mem_palace.cpython-310.pyc ADDED
Binary file (289 Bytes). View file
 
smriti_vendor/smriti/__pycache__/semantic_memory.cpython-310.pyc ADDED
Binary file (17.3 kB). View file
 
smriti_vendor/smriti/agent.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from contextlib import nullcontext
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+
5
+ from .core import MemPalaceLite
6
+ from .identity_fingerprint import IdentityFingerprint
7
+ from .macp import MACPLite
8
+
9
+ try:
10
+ import torch
11
+ except Exception:
12
+ torch = None
13
+
14
+ try:
15
+ from transformers import GenerationConfig
16
+ except Exception:
17
+ GenerationConfig = None
18
+
19
+
20
+ class SmritiAILite:
21
+ """
22
+ Model-agnostic SLM wrapper with semantic memory, graph memory, reasoning
23
+ continuity, and GIFP v1.0 identity governance.
24
+
25
+ Pass any pre-loaded HuggingFace causal LM and tokenizer.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ model: Any,
31
+ tokenizer: Any,
32
+ memory_path: Optional[str] = None,
33
+ retrieval_mode: str = "semantic",
34
+ session_id: str = "default",
35
+ topic_id: str = "general",
36
+ memory: Optional[MemPalaceLite] = None,
37
+ identity: Optional[IdentityFingerprint] = None,
38
+ auto_device: bool = True,
39
+ ):
40
+ self.model = model
41
+ self.tokenizer = tokenizer
42
+ self.session_id = session_id
43
+ self.topic_id = topic_id
44
+
45
+ if memory is not None:
46
+ self.memory = memory
47
+ elif memory_path and os.path.exists(memory_path):
48
+ self.memory = MemPalaceLite.load(memory_path, retrieval_mode=retrieval_mode)
49
+ else:
50
+ self.memory = MemPalaceLite(
51
+ retrieval_mode=retrieval_mode,
52
+ session_id=session_id,
53
+ topic_id=topic_id,
54
+ )
55
+
56
+ self.continuity = MACPLite()
57
+ self.identity = identity or IdentityFingerprint(
58
+ role="helpful AI assistant with persistent memory"
59
+ )
60
+ self.identity.set_constraints(
61
+ [
62
+ "Always be helpful and accurate",
63
+ "Reference previous context when relevant",
64
+ "Maintain logical consistency across turns",
65
+ "Acknowledge uncertainty when present",
66
+ ]
67
+ )
68
+ self.device, self.autocast_dtype = configure_inference_device()
69
+ if auto_device:
70
+ self._move_model_to_best_device()
71
+
72
+ def build_prompt(self, user_input: str) -> str:
73
+ identity = self.identity.get_identity_prompt()
74
+ ctx = self.memory.get_context(
75
+ query=user_input,
76
+ session_id=self.session_id,
77
+ topic_id=self.topic_id,
78
+ )
79
+ if ctx:
80
+ return identity + "\n" + ctx + "\n\n" + user_input
81
+ return identity + "\n" + user_input
82
+
83
+ def _generate(self, prompt: str, max_tokens: int = 256) -> str:
84
+ if torch is None or GenerationConfig is None:
85
+ raise RuntimeError("torch and transformers are required for model generation.")
86
+
87
+ messages = [{"role": "user", "content": prompt}]
88
+ try:
89
+ formatted = self.tokenizer.apply_chat_template(
90
+ messages, tokenize=False, add_generation_prompt=True
91
+ )
92
+ except Exception:
93
+ formatted = prompt
94
+
95
+ inputs = self.tokenizer(
96
+ formatted,
97
+ return_tensors="pt",
98
+ truncation=True,
99
+ max_length=2048,
100
+ )
101
+ model_device = _model_device(self.model) or self.device
102
+ inputs = {key: value.to(model_device) for key, value in inputs.items()}
103
+ cfg = GenerationConfig(
104
+ max_new_tokens=max_tokens,
105
+ temperature=0.7,
106
+ top_p=0.9,
107
+ do_sample=True,
108
+ pad_token_id=getattr(self.tokenizer, "eos_token_id", None),
109
+ )
110
+ with torch.inference_mode(), _autocast_context(model_device, self.autocast_dtype):
111
+ out = self.model.generate(**inputs, generation_config=cfg)
112
+ return self.tokenizer.decode(
113
+ out[0, inputs["input_ids"].shape[1] :].detach().cpu(),
114
+ skip_special_tokens=True,
115
+ ).strip()
116
+
117
+ def chat(self, user_input: str, refine: bool = False) -> str:
118
+ self.continuity.start_chain(user_input)
119
+ self.identity.observe_user_input(user_input)
120
+ prompt = self.build_prompt(user_input)
121
+ response = self._generate(prompt)
122
+
123
+ context = self.memory.get_context(
124
+ query=user_input,
125
+ session_id=self.session_id,
126
+ topic_id=self.topic_id,
127
+ )
128
+ response, identity_check = self.identity.ensure_aligned(
129
+ response,
130
+ self._generate,
131
+ user_input=user_input,
132
+ context=context,
133
+ )
134
+ if refine and identity_check.consistency_score < 0.5:
135
+ response = self.identity.refinement_pass(
136
+ self._generate,
137
+ response,
138
+ user_input=user_input,
139
+ context=context,
140
+ )
141
+ identity_check = self.identity.evaluate_output(response)
142
+
143
+ self.continuity.add_step(
144
+ user_input,
145
+ response,
146
+ identity_check.consistency_score,
147
+ "continue" if identity_check.consistency_score > 0.7 else "refine",
148
+ )
149
+ for fact in self.memory.extract_facts(response, user_input=user_input):
150
+ self.memory.add_fact(fact, session_id=self.session_id, topic_id=self.topic_id)
151
+ self.memory.add_to_history("User: " + user_input, "user_input")
152
+ self.memory.add_to_history(
153
+ "Assistant: " + response[:200], "assistant_output"
154
+ )
155
+ self.identity.record_behavior(response)
156
+ return response
157
+
158
+ def save_memory(self, path: str):
159
+ self.memory.save(path)
160
+
161
+ def load_memory(self, path: str):
162
+ self.memory = MemPalaceLite.load(path, retrieval_mode=self.memory.retrieval_mode)
163
+
164
+ def get_memory_state(self) -> Dict:
165
+ return self.memory.to_dict()
166
+
167
+ def get_reasoning_chain(self) -> str:
168
+ return self.continuity.get_chain_summary()
169
+
170
+ def _move_model_to_best_device(self) -> None:
171
+ if torch is None or self.device is None or str(self.device) == "cpu":
172
+ return
173
+ try:
174
+ current = _model_device(self.model)
175
+ if current is not None and str(current).startswith("cuda"):
176
+ return
177
+ self.model.to(self.device)
178
+ except Exception:
179
+ pass
180
+
181
+
182
+ class BaselineGemma:
183
+ """Plain causal LM with no memory, no identity layer, no continuity."""
184
+
185
+ def __init__(self, model: Any, tokenizer: Any, auto_device: bool = True):
186
+ self.model = model
187
+ self.tokenizer = tokenizer
188
+ self._history: List[str] = []
189
+ self.device, self.autocast_dtype = configure_inference_device()
190
+ if auto_device and torch is not None and str(self.device) != "cpu":
191
+ try:
192
+ self.model.to(self.device)
193
+ except Exception:
194
+ pass
195
+
196
+ def chat(self, user_input: str) -> str:
197
+ if torch is None or GenerationConfig is None:
198
+ raise RuntimeError("torch and transformers are required for model generation.")
199
+
200
+ messages = [{"role": "user", "content": user_input}]
201
+ try:
202
+ prompt = self.tokenizer.apply_chat_template(
203
+ messages, tokenize=False, add_generation_prompt=True
204
+ )
205
+ except Exception:
206
+ prompt = user_input
207
+ inputs = self.tokenizer(
208
+ prompt,
209
+ return_tensors="pt",
210
+ truncation=True,
211
+ max_length=2048,
212
+ )
213
+ model_device = _model_device(self.model) or self.device
214
+ inputs = {key: value.to(model_device) for key, value in inputs.items()}
215
+ cfg = GenerationConfig(
216
+ max_new_tokens=getattr(self, "max_new_tokens", 256),
217
+ temperature=0.7,
218
+ top_p=0.9,
219
+ do_sample=True,
220
+ pad_token_id=getattr(self.tokenizer, "eos_token_id", None),
221
+ )
222
+ with torch.inference_mode(), _autocast_context(model_device, self.autocast_dtype):
223
+ out = self.model.generate(**inputs, generation_config=cfg)
224
+ response = self.tokenizer.decode(
225
+ out[0, inputs["input_ids"].shape[1] :].detach().cpu(),
226
+ skip_special_tokens=True,
227
+ ).strip()
228
+ self._history.extend(["User: " + user_input, "Assistant: " + response])
229
+ return response
230
+
231
+ def reset(self):
232
+ self._history = []
233
+
234
+
235
+ def configure_inference_device() -> Tuple[Any, Any]:
236
+ """Return the preferred torch device and mixed-precision dtype."""
237
+
238
+ if torch is None:
239
+ return "cpu", None
240
+ if torch.cuda.is_available():
241
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
242
+ return torch.device("cuda"), dtype
243
+ return torch.device("cpu"), torch.float32
244
+
245
+
246
+ def _model_device(model: Any) -> Any:
247
+ try:
248
+ return next(model.parameters()).device
249
+ except Exception:
250
+ return None
251
+
252
+
253
+ def _autocast_context(device: Any, dtype: Any):
254
+ if torch is None or dtype is None:
255
+ return nullcontext()
256
+ if str(device).startswith("cuda"):
257
+ return torch.autocast(device_type="cuda", dtype=dtype)
258
+ return nullcontext()
259
+
260
+
261
+ # Backwards compatibility for existing user code.
262
+ GodelAILite = SmritiAILite
smriti_vendor/smriti/api.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI layer for multi-user, multi-agent Smriti AI memory access."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ import time
9
+ import uuid
10
+ from contextlib import contextmanager
11
+ from threading import RLock
12
+ from typing import Any, Callable, Dict, Iterator, List, Optional
13
+
14
+ from fastapi import FastAPI, HTTPException, Request, Response
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from pydantic import BaseModel, Field
17
+ from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, Histogram, generate_latest
18
+
19
+ from .backends import MemoryBackend, build_backend
20
+ from .config import configure_environment_from_file, load_config
21
+ from .core import MemPalaceLite
22
+
23
+
24
+ AgentFactory = Callable[..., Any]
25
+
26
+ USER_MEMORIES: Dict[str, MemPalaceLite] = {}
27
+ MEMORY_LOCK = RLock()
28
+ MEMORY_BACKEND: Optional[MemoryBackend] = None
29
+ AGENT_FACTORY: Optional[AgentFactory] = None
30
+ LOGGER = logging.getLogger("smriti.api")
31
+ LOGGER.setLevel(logging.INFO)
32
+ if not LOGGER.handlers:
33
+ _handler = logging.StreamHandler()
34
+ _handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
35
+ LOGGER.addHandler(_handler)
36
+ LOGGER.propagate = False
37
+
38
+ HTTP_REQUESTS = Counter(
39
+ "smriti_http_requests_total",
40
+ "Total HTTP requests handled by the Smriti AI API.",
41
+ ("method", "path", "status"),
42
+ )
43
+ HTTP_ERRORS = Counter(
44
+ "smriti_http_errors_total",
45
+ "Total HTTP requests that completed with status >= 500.",
46
+ ("method", "path"),
47
+ )
48
+ HTTP_LATENCY = Histogram(
49
+ "smriti_http_request_latency_seconds",
50
+ "End-to-end HTTP request latency.",
51
+ ("method", "path"),
52
+ )
53
+ RETRIEVAL_LATENCY = Histogram(
54
+ "smriti_retrieval_latency_seconds",
55
+ "Memory retrieval latency for chat requests.",
56
+ ("retrieval_mode",),
57
+ )
58
+ TOKEN_USAGE = Counter(
59
+ "smriti_tokens_total",
60
+ "Approximate whitespace-token count observed by the API.",
61
+ ("user_id", "agent_id"),
62
+ )
63
+ USER_MEMORY_COUNT = Gauge(
64
+ "smriti_user_memories",
65
+ "Number of in-memory user memory stores.",
66
+ )
67
+ USER_MEMORY_BYTES = Gauge(
68
+ "smriti_user_memory_bytes",
69
+ "Approximate serialized memory size by user.",
70
+ ("user_id",),
71
+ )
72
+
73
+
74
+ class ChatRequest(BaseModel):
75
+ user_id: str
76
+ message: str
77
+ topic_id: str = "general"
78
+ agent_id: str = "executor"
79
+ retrieval_mode: str = "semantic"
80
+
81
+
82
+ class ChatResponse(BaseModel):
83
+ user_id: str
84
+ agent_id: str
85
+ topic_id: str
86
+ response: str
87
+ retrieved_context: str
88
+ memory: Dict[str, Any]
89
+
90
+
91
+ class MemoryLoadRequest(BaseModel):
92
+ user_id: str
93
+ memory: Optional[Dict[str, Any]] = None
94
+ path: Optional[str] = None
95
+ retrieval_mode: str = "semantic"
96
+
97
+
98
+ class MemorySaveRequest(BaseModel):
99
+ user_id: str
100
+ path: Optional[str] = None
101
+
102
+
103
+ class MemoryDeleteRequest(BaseModel):
104
+ user_id: str
105
+ path: Optional[str] = None
106
+
107
+
108
+ class GraphQueryRequest(BaseModel):
109
+ user_id: str
110
+ query_entity: str
111
+ topic_id: Optional[str] = None
112
+ depth: int = Field(default=1, ge=1, le=4)
113
+
114
+
115
+ def set_agent_factory(factory: Optional[AgentFactory]) -> None:
116
+ """
117
+ Register a callable that returns a configured model agent.
118
+
119
+ The callable receives `user_id`, `memory`, `topic_id`, and `agent_id`.
120
+ When no factory is configured, `/chat` runs in memory-only mode.
121
+ """
122
+
123
+ global AGENT_FACTORY
124
+ AGENT_FACTORY = factory
125
+
126
+
127
+ def set_memory_backend(backend: Optional[MemoryBackend]) -> None:
128
+ """Override the configured persistence backend for tests or deployments."""
129
+
130
+ global MEMORY_BACKEND
131
+ MEMORY_BACKEND = backend
132
+
133
+
134
+ def get_memory_backend() -> MemoryBackend:
135
+ """Return the configured durable backend, constructing it lazily from env."""
136
+
137
+ global MEMORY_BACKEND
138
+ if MEMORY_BACKEND is None:
139
+ configure_environment_from_file()
140
+ MEMORY_BACKEND = build_backend()
141
+ return MEMORY_BACKEND
142
+
143
+
144
+ def get_memory(user_id: str, retrieval_mode: str = "semantic") -> MemPalaceLite:
145
+ with MEMORY_LOCK:
146
+ if user_id not in USER_MEMORIES:
147
+ state = None
148
+ try:
149
+ state = get_memory_backend().load(user_id)
150
+ except Exception:
151
+ LOGGER.exception("Durable memory load failed; starting empty memory")
152
+ if state:
153
+ memory = MemPalaceLite.from_dict(state, retrieval_mode=retrieval_mode)
154
+ memory.session_id = user_id
155
+ else:
156
+ memory = MemPalaceLite(
157
+ retrieval_mode=retrieval_mode,
158
+ session_id=user_id,
159
+ )
160
+ USER_MEMORIES[user_id] = memory
161
+ USER_MEMORY_COUNT.set(len(USER_MEMORIES))
162
+ return USER_MEMORIES[user_id]
163
+
164
+
165
+ def create_app() -> FastAPI:
166
+ config = configure_environment_from_file()
167
+ app = FastAPI(
168
+ title="Smriti AI API",
169
+ version="0.3.1",
170
+ description="Semantic memory, knowledge graph and identity governance API.",
171
+ )
172
+ app.add_middleware(
173
+ CORSMiddleware,
174
+ allow_origins=config.cors_origins,
175
+ allow_credentials=True,
176
+ allow_methods=["*"],
177
+ allow_headers=["*"],
178
+ )
179
+
180
+ @app.middleware("http")
181
+ async def request_observability(request: Request, call_next: Callable[..., Any]) -> Response:
182
+ request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
183
+ path = request.url.path
184
+ start = time.perf_counter()
185
+ status_code = 500
186
+ try:
187
+ _enforce_api_key(request)
188
+ response = await call_next(request)
189
+ status_code = response.status_code
190
+ except HTTPException as exc:
191
+ status_code = exc.status_code
192
+ response = Response(
193
+ content=f'{{"detail":"{exc.detail}"}}',
194
+ status_code=exc.status_code,
195
+ media_type="application/json",
196
+ )
197
+ except Exception:
198
+ LOGGER.exception(
199
+ "Unhandled API request failure",
200
+ extra={"request_id": request_id, "path": path},
201
+ )
202
+ response = Response(
203
+ content='{"detail":"Internal server error"}',
204
+ status_code=500,
205
+ media_type="application/json",
206
+ )
207
+ duration = time.perf_counter() - start
208
+ HTTP_LATENCY.labels(request.method, path).observe(duration)
209
+ HTTP_REQUESTS.labels(request.method, path, str(status_code)).inc()
210
+ if status_code >= 500:
211
+ HTTP_ERRORS.labels(request.method, path).inc()
212
+ USER_MEMORY_COUNT.set(len(USER_MEMORIES))
213
+ response.headers["x-request-id"] = request_id
214
+ LOGGER.info(
215
+ "request completed request_id=%s method=%s path=%s status=%s duration_s=%.6f",
216
+ request_id,
217
+ request.method,
218
+ path,
219
+ status_code,
220
+ duration,
221
+ )
222
+ return response
223
+
224
+ @app.get("/health")
225
+ def health() -> Dict[str, Any]:
226
+ return {"status": "ok", "users": len(USER_MEMORIES)}
227
+
228
+ @app.get("/metrics")
229
+ def metrics() -> Response:
230
+ return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
231
+
232
+ @app.post("/chat", response_model=ChatResponse)
233
+ def chat(request: ChatRequest) -> ChatResponse:
234
+ memory = get_memory(request.user_id, retrieval_mode=request.retrieval_mode)
235
+ memory.retrieval_mode = request.retrieval_mode
236
+ memory.topic_id = request.topic_id
237
+ context, degraded, warnings = _safe_get_context(
238
+ memory,
239
+ query=request.message,
240
+ session_id=request.user_id,
241
+ topic_id=request.topic_id,
242
+ retrieval_mode=request.retrieval_mode,
243
+ )
244
+
245
+ if AGENT_FACTORY is not None:
246
+ agent = _build_agent(
247
+ AGENT_FACTORY,
248
+ user_id=request.user_id,
249
+ memory=memory,
250
+ topic_id=request.topic_id,
251
+ agent_id=request.agent_id,
252
+ )
253
+ with MEMORY_LOCK:
254
+ try:
255
+ response = agent.chat(request.message)
256
+ except Exception as exc:
257
+ LOGGER.exception("Agent factory chat failed")
258
+ degraded = True
259
+ warnings.append(f"agent_failure:{exc.__class__.__name__}")
260
+ response = _memory_only_response(context)
261
+ state = memory.to_dict()
262
+ else:
263
+ response = _memory_only_response(context)
264
+ with MEMORY_LOCK:
265
+ _safe_update_memory(
266
+ memory,
267
+ request.message,
268
+ response,
269
+ request.user_id,
270
+ request.topic_id,
271
+ warnings,
272
+ )
273
+ state = memory.to_dict()
274
+ _persist_if_configured(request.user_id, state, warnings)
275
+
276
+ TOKEN_USAGE.labels(request.user_id, request.agent_id).inc(
277
+ _count_tokens(request.message) + _count_tokens(response)
278
+ )
279
+ state["_degraded"] = degraded
280
+ state["_warnings"] = warnings
281
+ return ChatResponse(
282
+ user_id=request.user_id,
283
+ agent_id=request.agent_id,
284
+ topic_id=request.topic_id,
285
+ response=response,
286
+ retrieved_context=context,
287
+ memory=state,
288
+ )
289
+
290
+ @app.post("/memory/load")
291
+ def load_memory(request: MemoryLoadRequest) -> Dict[str, Any]:
292
+ with MEMORY_LOCK:
293
+ if request.path:
294
+ memory = MemPalaceLite.load(
295
+ request.path,
296
+ retrieval_mode=request.retrieval_mode,
297
+ )
298
+ elif request.memory:
299
+ memory = MemPalaceLite.from_dict(
300
+ request.memory or {},
301
+ retrieval_mode=request.retrieval_mode,
302
+ )
303
+ else:
304
+ state = get_memory_backend().load(request.user_id)
305
+ if state is None:
306
+ raise HTTPException(status_code=404, detail="No memory found for user.")
307
+ memory = MemPalaceLite.from_dict(
308
+ state,
309
+ retrieval_mode=request.retrieval_mode,
310
+ )
311
+ memory.session_id = request.user_id
312
+ USER_MEMORIES[request.user_id] = memory
313
+ return memory.to_dict()
314
+
315
+ @app.post("/memory/save")
316
+ def save_memory(request: MemorySaveRequest) -> Dict[str, Any]:
317
+ memory = get_memory(request.user_id)
318
+ with MEMORY_LOCK:
319
+ if request.path:
320
+ memory.save(request.path)
321
+ state = memory.to_dict()
322
+ if not request.path:
323
+ get_memory_backend().save(request.user_id, state)
324
+ _observe_memory_size(request.user_id, state)
325
+ return state
326
+
327
+ @app.post("/memory/delete")
328
+ def delete_memory(request: MemoryDeleteRequest) -> Dict[str, Any]:
329
+ with MEMORY_LOCK:
330
+ existed = USER_MEMORIES.pop(request.user_id, None) is not None
331
+ USER_MEMORY_COUNT.set(len(USER_MEMORIES))
332
+ deleted_file = False
333
+ if request.path and os.path.exists(request.path):
334
+ os.remove(request.path)
335
+ deleted_file = True
336
+ deleted_backend = False
337
+ try:
338
+ deleted_backend = get_memory_backend().delete_user(request.user_id)
339
+ except Exception:
340
+ LOGGER.exception("Durable memory deletion failed")
341
+ try:
342
+ USER_MEMORY_BYTES.remove(request.user_id)
343
+ except Exception:
344
+ pass
345
+ return {
346
+ "user_id": request.user_id,
347
+ "deleted_memory": existed,
348
+ "deleted_file": deleted_file,
349
+ "deleted_backend": deleted_backend,
350
+ "remaining_users": len(USER_MEMORIES),
351
+ }
352
+
353
+ @app.post("/graph/query")
354
+ def graph_query(request: GraphQueryRequest) -> Dict[str, Any]:
355
+ memory = get_memory(request.user_id)
356
+ try:
357
+ triples = memory.knowledge_graph.query_graph(
358
+ request.user_id,
359
+ request.query_entity,
360
+ depth=request.depth,
361
+ topic_id=request.topic_id,
362
+ )
363
+ degraded = False
364
+ warnings: List[str] = []
365
+ except Exception as exc:
366
+ LOGGER.exception("Knowledge graph query failed")
367
+ triples = []
368
+ degraded = True
369
+ warnings = [f"knowledge_graph_failure:{exc.__class__.__name__}"]
370
+ return {
371
+ "user_id": request.user_id,
372
+ "query_entity": request.query_entity,
373
+ "triples": [triple.__dict__ for triple in triples],
374
+ "facts": memory.knowledge_graph.triples_to_text(triples),
375
+ "degraded": degraded,
376
+ "warnings": warnings,
377
+ }
378
+
379
+ return app
380
+
381
+
382
+ def _build_agent(factory: AgentFactory, **kwargs: Any) -> Any:
383
+ try:
384
+ return factory(**kwargs)
385
+ except TypeError:
386
+ return factory(kwargs["memory"])
387
+
388
+
389
+ def _memory_only_response(context: str) -> str:
390
+ if context:
391
+ bullets = []
392
+ for line in context.splitlines():
393
+ cleaned = line.strip()
394
+ if cleaned.startswith("* "):
395
+ bullets.append(cleaned[2:].strip())
396
+ if bullets:
397
+ rendered = "\n".join(f"- {fact}" for fact in bullets[:5])
398
+ return f"Memory updated. I found relevant context:\n{rendered}"
399
+ return "Memory updated. Relevant context is available for the configured model."
400
+ return "Memory updated. No prior relevant context was found."
401
+
402
+
403
+ def _safe_get_context(
404
+ memory: MemPalaceLite,
405
+ query: str,
406
+ session_id: str,
407
+ topic_id: str,
408
+ retrieval_mode: str,
409
+ ) -> tuple[str, bool, List[str]]:
410
+ warnings: List[str] = []
411
+ with _observe_retrieval(retrieval_mode):
412
+ try:
413
+ return (
414
+ memory.get_context(
415
+ query=query,
416
+ session_id=session_id,
417
+ topic_id=topic_id,
418
+ ),
419
+ False,
420
+ warnings,
421
+ )
422
+ except Exception as exc:
423
+ LOGGER.exception("Primary retrieval failed; degrading to TF-IDF/no-graph context")
424
+ warnings.append(f"primary_retrieval_failure:{exc.__class__.__name__}")
425
+
426
+ original_mode = memory.retrieval_mode
427
+ try:
428
+ memory.retrieval_mode = "tfidf"
429
+ with _observe_retrieval("tfidf"):
430
+ context = memory.get_context(
431
+ query=query,
432
+ session_id=session_id,
433
+ topic_id=topic_id,
434
+ include_graph=False,
435
+ )
436
+ warnings.append("degraded_to_tfidf")
437
+ return context, True, warnings
438
+ except Exception as exc:
439
+ LOGGER.exception("Fallback TF-IDF retrieval failed")
440
+ warnings.append(f"fallback_retrieval_failure:{exc.__class__.__name__}")
441
+ return "", True, warnings
442
+ finally:
443
+ memory.retrieval_mode = original_mode
444
+
445
+
446
+ def _safe_update_memory(
447
+ memory: MemPalaceLite,
448
+ user_message: str,
449
+ response: str,
450
+ session_id: str,
451
+ topic_id: str,
452
+ warnings: List[str],
453
+ ) -> None:
454
+ try:
455
+ for fact in memory.extract_facts("", user_input=user_message):
456
+ try:
457
+ memory.add_fact(fact, session_id=session_id, topic_id=topic_id)
458
+ except Exception as exc:
459
+ LOGGER.exception("Fact storage failed")
460
+ warnings.append(f"fact_storage_failure:{exc.__class__.__name__}")
461
+ memory.add_to_history("User: " + user_message, "user_input")
462
+ memory.add_to_history("Assistant: " + response[:200], "assistant_output")
463
+ except Exception as exc:
464
+ LOGGER.exception("Memory update failed")
465
+ warnings.append(f"memory_update_failure:{exc.__class__.__name__}")
466
+
467
+
468
+ def _persist_if_configured(user_id: str, state: Dict[str, Any], warnings: List[str]) -> None:
469
+ if os.getenv("SMRITI_AUTOSAVE", "0").lower() not in {"1", "true", "yes"}:
470
+ _observe_memory_size(user_id, state)
471
+ return
472
+ try:
473
+ get_memory_backend().save(user_id, state)
474
+ _observe_memory_size(user_id, state)
475
+ except Exception as exc:
476
+ LOGGER.exception("Durable memory autosave failed")
477
+ warnings.append(f"autosave_failure:{exc.__class__.__name__}")
478
+
479
+
480
+ def _observe_memory_size(user_id: str, state: Dict[str, Any]) -> None:
481
+ try:
482
+ USER_MEMORY_BYTES.labels(user_id).set(len(json.dumps(state)))
483
+ except Exception:
484
+ pass
485
+
486
+
487
+ @contextmanager
488
+ def _observe_retrieval(retrieval_mode: str) -> Iterator[None]:
489
+ start = time.perf_counter()
490
+ try:
491
+ yield
492
+ finally:
493
+ RETRIEVAL_LATENCY.labels(retrieval_mode).observe(time.perf_counter() - start)
494
+
495
+
496
+ def _count_tokens(text: str) -> int:
497
+ return max(1, len(text.split())) if text else 0
498
+
499
+
500
+ def _enforce_api_key(request: Request) -> None:
501
+ expected = os.getenv("SMRITI_API_KEY")
502
+ if not expected:
503
+ return
504
+ if request.url.path in {"/health", "/metrics", "/docs", "/openapi.json"}:
505
+ return
506
+ supplied = request.headers.get("x-api-key")
507
+ if supplied != expected:
508
+ raise HTTPException(status_code=401, detail="Invalid or missing API key.")
509
+
510
+
511
+ app = create_app()
512
+
513
+
514
+ def main(argv: Optional[List[str]] = None) -> None:
515
+ """Run the API with `python -m smriti.api` or the `smriti-api` entry point."""
516
+
517
+ import argparse
518
+ import uvicorn
519
+
520
+ parser = argparse.ArgumentParser(description="Run the Smriti AI FastAPI service.")
521
+ parser.add_argument("--config", help="Path to config.yaml.")
522
+ parser.add_argument("--host", help="Bind host. Defaults to config or SMRITI_HOST.")
523
+ parser.add_argument("--port", type=int, help="Bind port. Defaults to config or SMRITI_PORT.")
524
+ parser.add_argument("--reload", action="store_true", help="Enable Uvicorn reload mode.")
525
+ args = parser.parse_args(argv)
526
+ if args.config:
527
+ os.environ["SMRITI_CONFIG_PATH"] = args.config
528
+ config = load_config()
529
+ uvicorn.run(
530
+ "smriti.api:app" if args.reload else app,
531
+ host=args.host or os.getenv("SMRITI_HOST", config.host),
532
+ port=args.port or int(os.getenv("SMRITI_PORT", config.port)),
533
+ reload=args.reload,
534
+ )
535
+
536
+
537
+ if __name__ == "__main__":
538
+ main()
smriti_vendor/smriti/backends.py ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Durable memory backends for Smriti AI.
2
+
3
+ Backends persist complete user memory blobs and also expose a minimal entry API
4
+ for tools that want to store/retrieve lightweight facts without instantiating the
5
+ full runtime. Optional encryption is applied at the blob boundary so JSON, SQL,
6
+ Redis, and Postgres stores share the same privacy behavior.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import base64
12
+ import hashlib
13
+ import json
14
+ import os
15
+ import re
16
+ import sqlite3
17
+ import time
18
+ from abc import ABC, abstractmethod
19
+ from pathlib import Path
20
+ from typing import Any, Dict, List, Optional
21
+
22
+
23
+ class MemoryBackend(ABC):
24
+ """Abstract persistence contract for user-isolated Smriti AI memory."""
25
+
26
+ @abstractmethod
27
+ def load(self, user_id: str) -> Optional[Dict[str, Any]]:
28
+ """Load a complete memory state for one user, or None if absent."""
29
+
30
+ @abstractmethod
31
+ def save(self, user_id: str, memory: Dict[str, Any]) -> None:
32
+ """Persist a complete memory state for one user."""
33
+
34
+ @abstractmethod
35
+ def add_entry(
36
+ self,
37
+ user_id: str,
38
+ session_id: str,
39
+ topic_id: str,
40
+ text: str,
41
+ metadata: Optional[Dict[str, Any]] = None,
42
+ ) -> None:
43
+ """Persist one lightweight fact/entry for a user/session/topic."""
44
+
45
+ @abstractmethod
46
+ def retrieve(
47
+ self,
48
+ user_id: str,
49
+ session_id: Optional[str] = None,
50
+ topic_id: Optional[str] = None,
51
+ query: str = "",
52
+ k: int = 5,
53
+ ) -> List[Dict[str, Any]]:
54
+ """Retrieve lightweight entries scoped to a user/session/topic."""
55
+
56
+ @abstractmethod
57
+ def delete_user(self, user_id: str) -> bool:
58
+ """Delete all memory owned by one user. Return whether anything existed."""
59
+
60
+
61
+ class MemoryCipher:
62
+ """Optional symmetric encryption wrapper using Fernet when configured."""
63
+
64
+ def __init__(self, secret: Optional[str] = None):
65
+ self.secret = secret or os.getenv("SMRITI_MEMORY_KEY")
66
+ self._fernet = None
67
+ if self.secret:
68
+ try:
69
+ from cryptography.fernet import Fernet
70
+ except Exception as exc: # pragma: no cover - depends on optional install.
71
+ raise RuntimeError(
72
+ "SMRITI_MEMORY_KEY is set, but cryptography is not installed. "
73
+ "Install smriti-ai[security] or smriti-ai[full]."
74
+ ) from exc
75
+ self._fernet = Fernet(_fernet_key(self.secret))
76
+
77
+ @property
78
+ def enabled(self) -> bool:
79
+ return self._fernet is not None
80
+
81
+ def wrap(self, payload: Dict[str, Any]) -> Dict[str, Any]:
82
+ if not self._fernet:
83
+ return {"encrypted": False, "payload": payload}
84
+ raw = json.dumps(payload, sort_keys=True).encode("utf-8")
85
+ return {
86
+ "encrypted": True,
87
+ "algorithm": "fernet-sha256-derived-key",
88
+ "payload": self._fernet.encrypt(raw).decode("utf-8"),
89
+ }
90
+
91
+ def unwrap(self, wrapper: Dict[str, Any]) -> Dict[str, Any]:
92
+ if not wrapper.get("encrypted"):
93
+ return dict(wrapper.get("payload", {}))
94
+ if not self._fernet:
95
+ raise RuntimeError("Memory blob is encrypted but SMRITI_MEMORY_KEY is not configured.")
96
+ decrypted = self._fernet.decrypt(wrapper["payload"].encode("utf-8"))
97
+ return json.loads(decrypted.decode("utf-8"))
98
+
99
+
100
+ def build_backend(kind: Optional[str] = None, **kwargs: Any) -> MemoryBackend:
101
+ """Construct a backend from an explicit kind or SMRITI_MEMORY_BACKEND."""
102
+
103
+ selected = (kind or os.getenv("SMRITI_MEMORY_BACKEND") or "json").lower()
104
+ if selected == "json":
105
+ return JsonBackend(root=kwargs.get("root") or os.getenv("SMRITI_MEMORY_DIR", "data/memory"))
106
+ if selected == "sqlite":
107
+ return SqliteBackend(path=kwargs.get("path") or os.getenv("SMRITI_SQLITE_PATH", "data/smriti_memory.sqlite3"))
108
+ if selected == "redis":
109
+ return RedisBackend(url=kwargs.get("url") or os.getenv("SMRITI_REDIS_URL", "redis://localhost:6379/0"))
110
+ if selected in {"postgres", "postgresql"}:
111
+ return PostgresBackend(dsn=kwargs.get("dsn") or os.getenv("SMRITI_POSTGRES_DSN", ""))
112
+ raise ValueError("SMRITI_MEMORY_BACKEND must be one of: json, sqlite, redis, postgres.")
113
+
114
+
115
+ class JsonBackend(MemoryBackend):
116
+ """File-per-user JSON backend. This preserves the original local behavior."""
117
+
118
+ def __init__(self, root: str | Path = "data/memory", cipher: Optional[MemoryCipher] = None):
119
+ self.root = Path(root)
120
+ self.cipher = cipher or MemoryCipher()
121
+
122
+ def load(self, user_id: str) -> Optional[Dict[str, Any]]:
123
+ path = self._path(user_id)
124
+ if not path.exists():
125
+ return None
126
+ return self.cipher.unwrap(json.loads(path.read_text(encoding="utf-8")))
127
+
128
+ def save(self, user_id: str, memory: Dict[str, Any]) -> None:
129
+ self.root.mkdir(parents=True, exist_ok=True)
130
+ self._path(user_id).write_text(
131
+ json.dumps(self.cipher.wrap(memory), indent=2),
132
+ encoding="utf-8",
133
+ )
134
+
135
+ def add_entry(
136
+ self,
137
+ user_id: str,
138
+ session_id: str,
139
+ topic_id: str,
140
+ text: str,
141
+ metadata: Optional[Dict[str, Any]] = None,
142
+ ) -> None:
143
+ state = self.load(user_id) or {"backend_entries": []}
144
+ state.setdefault("backend_entries", []).append(_entry(session_id, topic_id, text, metadata))
145
+ self.save(user_id, state)
146
+
147
+ def retrieve(
148
+ self,
149
+ user_id: str,
150
+ session_id: Optional[str] = None,
151
+ topic_id: Optional[str] = None,
152
+ query: str = "",
153
+ k: int = 5,
154
+ ) -> List[Dict[str, Any]]:
155
+ state = self.load(user_id) or {}
156
+ return _rank_entries(state.get("backend_entries", []), session_id, topic_id, query, k)
157
+
158
+ def delete_user(self, user_id: str) -> bool:
159
+ path = self._path(user_id)
160
+ existed = path.exists()
161
+ if existed:
162
+ path.unlink()
163
+ return existed
164
+
165
+ def _path(self, user_id: str) -> Path:
166
+ return self.root / f"{_safe_id(user_id)}.json"
167
+
168
+
169
+ class SqliteBackend(MemoryBackend):
170
+ """SQLite backend for local durable multi-user memory."""
171
+
172
+ def __init__(self, path: str | Path = "data/smriti_memory.sqlite3", cipher: Optional[MemoryCipher] = None):
173
+ self.path = Path(path)
174
+ self.cipher = cipher or MemoryCipher()
175
+ self._init_schema()
176
+
177
+ def load(self, user_id: str) -> Optional[Dict[str, Any]]:
178
+ with self._connect() as conn:
179
+ row = conn.execute("SELECT payload FROM user_memory WHERE user_id = ?", (user_id,)).fetchone()
180
+ if not row:
181
+ return None
182
+ return self.cipher.unwrap(json.loads(row[0]))
183
+
184
+ def save(self, user_id: str, memory: Dict[str, Any]) -> None:
185
+ payload = json.dumps(self.cipher.wrap(memory))
186
+ with self._connect() as conn:
187
+ conn.execute(
188
+ """
189
+ INSERT INTO user_memory(user_id, payload, updated_at)
190
+ VALUES(?, ?, ?)
191
+ ON CONFLICT(user_id) DO UPDATE SET payload=excluded.payload, updated_at=excluded.updated_at
192
+ """,
193
+ (user_id, payload, time.time()),
194
+ )
195
+
196
+ def add_entry(
197
+ self,
198
+ user_id: str,
199
+ session_id: str,
200
+ topic_id: str,
201
+ text: str,
202
+ metadata: Optional[Dict[str, Any]] = None,
203
+ ) -> None:
204
+ with self._connect() as conn:
205
+ conn.execute(
206
+ """
207
+ INSERT INTO memory_entries(user_id, session_id, topic_id, text, metadata, created_at)
208
+ VALUES(?, ?, ?, ?, ?, ?)
209
+ """,
210
+ (user_id, session_id, topic_id, text, json.dumps(metadata or {}), time.time()),
211
+ )
212
+
213
+ def retrieve(
214
+ self,
215
+ user_id: str,
216
+ session_id: Optional[str] = None,
217
+ topic_id: Optional[str] = None,
218
+ query: str = "",
219
+ k: int = 5,
220
+ ) -> List[Dict[str, Any]]:
221
+ clauses = ["user_id = ?"]
222
+ params: List[Any] = [user_id]
223
+ if session_id:
224
+ clauses.append("session_id = ?")
225
+ params.append(session_id)
226
+ if topic_id:
227
+ clauses.append("topic_id = ?")
228
+ params.append(topic_id)
229
+ params.append(max(1, k * 5))
230
+ sql = f"""
231
+ SELECT session_id, topic_id, text, metadata, created_at
232
+ FROM memory_entries
233
+ WHERE {' AND '.join(clauses)}
234
+ ORDER BY created_at DESC
235
+ LIMIT ?
236
+ """
237
+ with self._connect() as conn:
238
+ rows = conn.execute(sql, params).fetchall()
239
+ entries = [
240
+ {
241
+ "session_id": row[0],
242
+ "topic_id": row[1],
243
+ "text": row[2],
244
+ "metadata": json.loads(row[3] or "{}"),
245
+ "created_at": row[4],
246
+ }
247
+ for row in rows
248
+ ]
249
+ return _rank_entries(entries, session_id, topic_id, query, k)
250
+
251
+ def delete_user(self, user_id: str) -> bool:
252
+ with self._connect() as conn:
253
+ before = conn.total_changes
254
+ conn.execute("DELETE FROM user_memory WHERE user_id = ?", (user_id,))
255
+ conn.execute("DELETE FROM memory_entries WHERE user_id = ?", (user_id,))
256
+ return conn.total_changes > before
257
+
258
+ def _connect(self) -> sqlite3.Connection:
259
+ self.path.parent.mkdir(parents=True, exist_ok=True)
260
+ return sqlite3.connect(self.path)
261
+
262
+ def _init_schema(self) -> None:
263
+ with self._connect() as conn:
264
+ conn.execute(
265
+ """
266
+ CREATE TABLE IF NOT EXISTS user_memory(
267
+ user_id TEXT PRIMARY KEY,
268
+ payload TEXT NOT NULL,
269
+ updated_at REAL NOT NULL
270
+ )
271
+ """
272
+ )
273
+ conn.execute(
274
+ """
275
+ CREATE TABLE IF NOT EXISTS memory_entries(
276
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
277
+ user_id TEXT NOT NULL,
278
+ session_id TEXT NOT NULL,
279
+ topic_id TEXT NOT NULL,
280
+ text TEXT NOT NULL,
281
+ metadata TEXT NOT NULL,
282
+ created_at REAL NOT NULL
283
+ )
284
+ """
285
+ )
286
+ conn.execute(
287
+ "CREATE INDEX IF NOT EXISTS idx_entries_user_session_topic ON memory_entries(user_id, session_id, topic_id, created_at)"
288
+ )
289
+
290
+
291
+ class RedisBackend(MemoryBackend): # pragma: no cover - requires external Redis service.
292
+ """Redis backend using string payloads and per-user entry lists."""
293
+
294
+ def __init__(self, url: str = "redis://localhost:6379/0", cipher: Optional[MemoryCipher] = None):
295
+ try:
296
+ import redis
297
+ except Exception as exc: # pragma: no cover - optional dependency.
298
+ raise RuntimeError("Install redis to use RedisBackend: pip install smriti-ai[backends]") from exc
299
+ self.client = redis.Redis.from_url(url, decode_responses=True)
300
+ self.cipher = cipher or MemoryCipher()
301
+
302
+ def load(self, user_id: str) -> Optional[Dict[str, Any]]:
303
+ raw = self.client.get(self._payload_key(user_id))
304
+ if not raw:
305
+ return None
306
+ return self.cipher.unwrap(json.loads(raw))
307
+
308
+ def save(self, user_id: str, memory: Dict[str, Any]) -> None:
309
+ self.client.set(self._payload_key(user_id), json.dumps(self.cipher.wrap(memory)))
310
+
311
+ def add_entry(self, user_id: str, session_id: str, topic_id: str, text: str, metadata: Optional[Dict[str, Any]] = None) -> None:
312
+ self.client.lpush(self._entries_key(user_id), json.dumps(_entry(session_id, topic_id, text, metadata)))
313
+
314
+ def retrieve(self, user_id: str, session_id: Optional[str] = None, topic_id: Optional[str] = None, query: str = "", k: int = 5) -> List[Dict[str, Any]]:
315
+ raw_entries = self.client.lrange(self._entries_key(user_id), 0, max(0, k * 5 - 1))
316
+ entries = [json.loads(item) for item in raw_entries]
317
+ return _rank_entries(entries, session_id, topic_id, query, k)
318
+
319
+ def delete_user(self, user_id: str) -> bool:
320
+ return bool(self.client.delete(self._payload_key(user_id), self._entries_key(user_id)))
321
+
322
+ def _payload_key(self, user_id: str) -> str:
323
+ return f"smriti:user:{_safe_id(user_id)}:payload"
324
+
325
+ def _entries_key(self, user_id: str) -> str:
326
+ return f"smriti:user:{_safe_id(user_id)}:entries"
327
+
328
+
329
+ class PostgresBackend(MemoryBackend): # pragma: no cover - requires external Postgres service.
330
+ """Postgres backend using psycopg2 and indexed user/session/topic tables."""
331
+
332
+ def __init__(self, dsn: str, cipher: Optional[MemoryCipher] = None):
333
+ if not dsn:
334
+ raise ValueError("SMRITI_POSTGRES_DSN is required for PostgresBackend.")
335
+ try:
336
+ import psycopg2
337
+ except Exception as exc: # pragma: no cover - optional dependency.
338
+ raise RuntimeError("Install psycopg2-binary to use PostgresBackend: pip install smriti-ai[backends]") from exc
339
+ self._psycopg2 = psycopg2
340
+ self.dsn = dsn
341
+ self.cipher = cipher or MemoryCipher()
342
+ self._init_schema()
343
+
344
+ def load(self, user_id: str) -> Optional[Dict[str, Any]]:
345
+ with self._connect() as conn, conn.cursor() as cur:
346
+ cur.execute("SELECT payload FROM user_memory WHERE user_id = %s", (user_id,))
347
+ row = cur.fetchone()
348
+ if not row:
349
+ return None
350
+ return self.cipher.unwrap(row[0] if isinstance(row[0], dict) else json.loads(row[0]))
351
+
352
+ def save(self, user_id: str, memory: Dict[str, Any]) -> None:
353
+ payload = json.dumps(self.cipher.wrap(memory))
354
+ with self._connect() as conn, conn.cursor() as cur:
355
+ cur.execute(
356
+ """
357
+ INSERT INTO user_memory(user_id, payload, updated_at)
358
+ VALUES(%s, %s::jsonb, NOW())
359
+ ON CONFLICT(user_id) DO UPDATE SET payload=excluded.payload, updated_at=excluded.updated_at
360
+ """,
361
+ (user_id, payload),
362
+ )
363
+
364
+ def add_entry(self, user_id: str, session_id: str, topic_id: str, text: str, metadata: Optional[Dict[str, Any]] = None) -> None:
365
+ with self._connect() as conn, conn.cursor() as cur:
366
+ cur.execute(
367
+ """
368
+ INSERT INTO memory_entries(user_id, session_id, topic_id, text, metadata)
369
+ VALUES(%s, %s, %s, %s, %s::jsonb)
370
+ """,
371
+ (user_id, session_id, topic_id, text, json.dumps(metadata or {})),
372
+ )
373
+
374
+ def retrieve(self, user_id: str, session_id: Optional[str] = None, topic_id: Optional[str] = None, query: str = "", k: int = 5) -> List[Dict[str, Any]]:
375
+ clauses = ["user_id = %s"]
376
+ params: List[Any] = [user_id]
377
+ if session_id:
378
+ clauses.append("session_id = %s")
379
+ params.append(session_id)
380
+ if topic_id:
381
+ clauses.append("topic_id = %s")
382
+ params.append(topic_id)
383
+ params.append(max(1, k * 5))
384
+ sql = f"""
385
+ SELECT session_id, topic_id, text, metadata, EXTRACT(EPOCH FROM created_at)
386
+ FROM memory_entries
387
+ WHERE {' AND '.join(clauses)}
388
+ ORDER BY created_at DESC
389
+ LIMIT %s
390
+ """
391
+ with self._connect() as conn, conn.cursor() as cur:
392
+ cur.execute(sql, params)
393
+ rows = cur.fetchall()
394
+ entries = [
395
+ {
396
+ "session_id": row[0],
397
+ "topic_id": row[1],
398
+ "text": row[2],
399
+ "metadata": row[3] or {},
400
+ "created_at": float(row[4]),
401
+ }
402
+ for row in rows
403
+ ]
404
+ return _rank_entries(entries, session_id, topic_id, query, k)
405
+
406
+ def delete_user(self, user_id: str) -> bool:
407
+ with self._connect() as conn, conn.cursor() as cur:
408
+ cur.execute("DELETE FROM user_memory WHERE user_id = %s", (user_id,))
409
+ memory_deleted = cur.rowcount
410
+ cur.execute("DELETE FROM memory_entries WHERE user_id = %s", (user_id,))
411
+ entries_deleted = cur.rowcount
412
+ return bool(memory_deleted or entries_deleted)
413
+
414
+ def _connect(self):
415
+ return self._psycopg2.connect(self.dsn)
416
+
417
+ def _init_schema(self) -> None:
418
+ with self._connect() as conn, conn.cursor() as cur:
419
+ cur.execute(
420
+ """
421
+ CREATE TABLE IF NOT EXISTS user_memory(
422
+ user_id TEXT PRIMARY KEY,
423
+ payload JSONB NOT NULL,
424
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
425
+ )
426
+ """
427
+ )
428
+ cur.execute(
429
+ """
430
+ CREATE TABLE IF NOT EXISTS memory_entries(
431
+ id BIGSERIAL PRIMARY KEY,
432
+ user_id TEXT NOT NULL,
433
+ session_id TEXT NOT NULL,
434
+ topic_id TEXT NOT NULL,
435
+ text TEXT NOT NULL,
436
+ metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
437
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
438
+ )
439
+ """
440
+ )
441
+ cur.execute(
442
+ "CREATE INDEX IF NOT EXISTS idx_smriti_entries_user_session_topic ON memory_entries(user_id, session_id, topic_id, created_at DESC)"
443
+ )
444
+
445
+
446
+ def _fernet_key(secret: str) -> bytes:
447
+ raw = secret.encode("utf-8")
448
+ try:
449
+ base64.urlsafe_b64decode(raw)
450
+ if len(raw) == 44:
451
+ return raw
452
+ except Exception:
453
+ pass
454
+ return base64.urlsafe_b64encode(hashlib.sha256(raw).digest())
455
+
456
+
457
+ def _safe_id(value: str) -> str:
458
+ return re.sub(r"[^a-zA-Z0-9_.-]+", "_", value.strip()) or "default"
459
+
460
+
461
+ def _entry(session_id: str, topic_id: str, text: str, metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
462
+ return {
463
+ "session_id": session_id,
464
+ "topic_id": topic_id,
465
+ "text": text,
466
+ "metadata": metadata or {},
467
+ "created_at": time.time(),
468
+ }
469
+
470
+
471
+ def _rank_entries(
472
+ entries: List[Dict[str, Any]],
473
+ session_id: Optional[str],
474
+ topic_id: Optional[str],
475
+ query: str,
476
+ k: int,
477
+ ) -> List[Dict[str, Any]]:
478
+ scoped = [
479
+ entry
480
+ for entry in entries
481
+ if (not session_id or entry.get("session_id") == session_id)
482
+ and (not topic_id or entry.get("topic_id") == topic_id)
483
+ ]
484
+ if not query.strip():
485
+ return sorted(scoped, key=lambda item: item.get("created_at", 0), reverse=True)[:k]
486
+ q_terms = set(re.findall(r"[a-z0-9']+", query.lower()))
487
+ scored = []
488
+ for entry in scoped:
489
+ terms = set(re.findall(r"[a-z0-9']+", entry.get("text", "").lower()))
490
+ overlap = len(q_terms & terms) / max(1, len(q_terms | terms))
491
+ recency = entry.get("created_at", 0)
492
+ scored.append((overlap, recency, entry))
493
+ scored.sort(key=lambda item: (item[0], item[1]), reverse=True)
494
+ return [entry for _, _, entry in scored[:k]]