diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7fb9e93e2e308f81e841d5c2b49a1c5fb11293a1 --- /dev/null +++ b/README.md @@ -0,0 +1,187 @@ +--- +license: apache-2.0 +language: + - en +library_name: smriti-ai +tags: + - ai-agent + - memory + - small-language-models + - inference-time-augmentation + - semantic-search + - knowledge-graph + - identity-continuity + - rag +pipeline_tag: text-generation +--- + +# Smriti AI + +## What this is + +Smriti AI is a memory-augmented inference layer for small language models. It adds external memory, semantic retrieval, knowledge-graph recall, identity continuity, and privacy-ready memory deletion without changing base model weights. + +This repository layout is intended for a Hugging Face model-style deployment with a custom `handler.py`. The handler loads a base causal language model or calls a remote model endpoint, wraps it with Smriti AI memory, and returns model responses plus retrieved memories. + +## What this is not + +Smriti AI is not a newly trained foundation model. It is not a fine-tuned model unless a separate fine-tuned checkpoint is explicitly included. It is an inference-time wrapper around a base language model. + +Do not interpret this repository as a standalone model checkpoint. The base model is configured through `BASE_MODEL_ID` or `HF_ENDPOINT_URL`. + +## Research Lineage + +Smriti AI follows four principles: + +- **External memory**: conversational facts live outside model weights in a persistent, inspectable store. +- **Training-free recall**: relevant facts are retrieved and injected at inference time without fine-tuning the base model. +- **Identity continuity**: persona evidence is tracked as an embedding fingerprint so outputs can be checked for drift. +- **Small-model augmentation**: small causal language models can become more useful when paired with explicit memory and retrieval. + +Historical GodelAI-Lite results were measured on an earlier system. Current Smriti AI results are measured separately and should not be conflated with historical results. + +## Architecture + +```text +User request + -> Smriti AI handler + -> memory retrieval + -> graph retrieval + -> identity context + -> base model inference + -> response + -> memory write/update +``` + +The handler supports JSON, SQLite, Redis, and Postgres memory backends. For production, use Redis/Postgres or another external durable store. Do not store private user memory in the Hugging Face model repository. + +## Supported base models + +Smriti AI is model-agnostic for Hugging Face causal language models. + +Supported families depend on the installed `transformers` version and endpoint hardware: + +- Gemma-style causal LMs when available, including the current benchmark path `google/gemma-4-E2B-it`. +- Llama/Phi/Mistral/Qwen-style causal LMs if supported by the runtime environment. +- Tiny CPU-safe local smoke-test models such as `sshleifer/tiny-gpt2` for handler validation only. + +Tiny models are useful for endpoint plumbing tests. They are not public Smriti AI quality benchmarks. + +## Evaluation + +Current local Gemma 4-only benchmark artifacts in the main Smriti AI repository report: + +| Evaluation | Baseline Recall | Smriti AI Recall | Notes | +|---|---:|---:|---| +| Gemma-style three-fact protocol | 0/3 | 3/3 | Smriti AI recalls all injected facts after distractors. | +| Five-mode comparison | 0/3 | 3/3 | TF-IDF, Semantic, Semantic+Graph, and Semantic+Graph+Identity all recall 3/3 in the checked-in run. | +| Original broader protocol rerun | 0/3 | 3/3 | Overall average improves from 0.524 to 0.832 (`+58.9%`) in the current local Gemma 4 CPU rerun. | + +Historical GodelAI-Lite results were measured on an earlier system. Current Smriti AI results are measured separately and should not be conflated with historical results. + +## Privacy + +Smriti AI stores user memory. Treat it as user data. + +- Memory can be encrypted by setting `SMRITI_ENCRYPTION_KEY`. +- `delete_memory` is supported by the handler. +- Production deployments should use external memory storage such as Redis/Postgres. +- Do not store private user memory in the Hugging Face model repository. +- Public/demo deployments should not receive real PII. + +## Limitations + +- Retrieval quality depends on the quality and specificity of stored memory. +- Public/demo deployments should not receive real PII. +- Durable memory requires external backend or persistent endpoint storage. +- Latency depends on the base model, backend, retrieval mode, and endpoint hardware. +- A tiny CPU demo model validates handler plumbing but will not produce Gemma-quality answers. +- If no `BASE_MODEL_ID` or `HF_ENDPOINT_URL` is configured, the handler falls back to memory-only responses. + +## Environment variables + +| Variable | Purpose | +|---|---| +| `BASE_MODEL_ID` | Hugging Face model ID to load inside the endpoint. | +| `HF_ENDPOINT_URL` | Optional remote model endpoint URL. If set, the handler calls this URL instead of loading a local base model. | +| `HF_TOKEN` | Token for gated/private base models or protected remote endpoints. | +| `SMRITI_MEMORY_BACKEND` | `json`, `sqlite`, `redis`, or `postgres`. | +| `SMRITI_MEMORY_PATH` | JSON user-memory directory or SQLite file path. | +| `REDIS_URL` | External Redis URL. Takes precedence when present. | +| `POSTGRES_DSN` | External Postgres DSN. Takes precedence when present and Redis is not configured. | +| `SMRITI_ENCRYPTION_KEY` | Memory encryption key. Do not commit it. | +| `SMRITI_RETRIEVAL_MODE` | `tfidf`, `semantic`, `semantic_graph`, or `semantic_graph_identity`. | +| `SMRITI_PUBLIC_DEMO` | `true` or `false`. Use `true` only for non-PII demos. | +| `SMRITI_MAX_MEMORY_ENTRIES` | Maximum retained entries per user/topic. | + +## How to call the endpoint + +### Chat / fact injection + +```json +{ + "inputs": { + "operation": "chat", + "user_id": "customer-123", + "message": "My name is Alex and I am a marine biologist.", + "retrieval_mode": "semantic_graph_identity" + }, + "parameters": { + "max_new_tokens": 256, + "temperature": 0.7, + "top_p": 0.9, + "return_memories": true + } +} +``` + +### Recall + +```json +{ + "inputs": { + "operation": "chat", + "user_id": "customer-123", + "message": "What do you remember about me?", + "retrieval_mode": "semantic_graph_identity" + }, + "parameters": { + "return_memories": true + } +} +``` + +### Delete memory + +```json +{ + "inputs": { + "operation": "delete_memory", + "user_id": "customer-123" + } +} +``` + +### Health + +```json +{ + "inputs": { + "operation": "health" + } +} +``` + +## Local test + +```bash +pip install -r requirements.txt +BASE_MODEL_ID=sshleifer/tiny-gpt2 \ +SMRITI_MEMORY_BACKEND=json \ +SMRITI_MEMORY_PATH=/tmp/smriti_hf_test.json \ +python test_handler_local.py +``` + +## Custom-container deployment + +If the standard Hugging Face handler is insufficient for your model size, CUDA libraries, Redis client policy, or enterprise network requirements, deploy the same files in a custom container. Use the main Smriti AI repository Dockerfiles as the starting point, install this handler, and expose a compatible HTTP API through Hugging Face Inference Endpoints custom container support. diff --git a/config.json b/config.json new file mode 100644 index 0000000000000000000000000000000000000000..59719376211742360ad6ca6635bf3e2304ea2f36 --- /dev/null +++ b/config.json @@ -0,0 +1,11 @@ +{ + "project": "Smriti AI", + "base_model": "REPLACE_WITH_BASE_MODEL_ID", + "retrieval_mode": "semantic_graph_identity", + "memory_backend": "json", + "public_demo": false, + "max_memory_entries": 1000, + "enable_identity": true, + "enable_graph": true, + "enable_encryption": true +} diff --git a/examples/request_delete.json b/examples/request_delete.json new file mode 100644 index 0000000000000000000000000000000000000000..eed81bb28b1582369397ebb080ed110c40dd967b --- /dev/null +++ b/examples/request_delete.json @@ -0,0 +1,6 @@ +{ + "inputs": { + "operation": "delete_memory", + "user_id": "demo-user" + } +} diff --git a/examples/request_distractor.json b/examples/request_distractor.json new file mode 100644 index 0000000000000000000000000000000000000000..d74a5a03360bb45793ac7c14ca46c2e56bcf56d2 --- /dev/null +++ b/examples/request_distractor.json @@ -0,0 +1,8 @@ +{ + "inputs": { + "operation": "chat", + "user_id": "demo-user", + "message": "What is the capital of France?", + "retrieval_mode": "semantic_graph_identity" + } +} diff --git a/examples/request_memory_inject.json b/examples/request_memory_inject.json new file mode 100644 index 0000000000000000000000000000000000000000..d86fa40382ed14d656f94bc14079eb26f717fc7f --- /dev/null +++ b/examples/request_memory_inject.json @@ -0,0 +1,11 @@ +{ + "inputs": { + "operation": "chat", + "user_id": "demo-user", + "message": "My name is Alex and I am a marine biologist based in Hawaii.", + "retrieval_mode": "semantic_graph_identity" + }, + "parameters": { + "return_memories": true + } +} diff --git a/examples/request_recall.json b/examples/request_recall.json new file mode 100644 index 0000000000000000000000000000000000000000..e49d1eb88623068dd14b9a2d266bb2dd7415f906 --- /dev/null +++ b/examples/request_recall.json @@ -0,0 +1,11 @@ +{ + "inputs": { + "operation": "chat", + "user_id": "demo-user", + "message": "What do you remember about me?", + "retrieval_mode": "semantic_graph_identity" + }, + "parameters": { + "return_memories": true + } +} diff --git a/handler.py b/handler.py new file mode 100644 index 0000000000000000000000000000000000000000..31cfa35aae3af4370e0dc680af133de4b977619c --- /dev/null +++ b/handler.py @@ -0,0 +1,647 @@ +"""Hugging Face custom inference handler for Smriti AI. + +This file is intentionally deployment glue. Core memory, retrieval, graph, and +identity behavior comes from the installed `smriti` package. +""" + +from __future__ import annotations + +import json +import logging +import os +import re +import sys +import time +import urllib.error +import urllib.request +from pathlib import Path +from threading import RLock +from typing import Any, Dict, List, Optional, Tuple + +VENDOR_SRC = Path(__file__).resolve().parent / "smriti_vendor" +if VENDOR_SRC.exists() and str(VENDOR_SRC) not in sys.path: + sys.path.insert(0, str(VENDOR_SRC)) + +from smriti import IdentityFingerprint, MemPalaceLite, SmritiAILite # noqa: E402 +from smriti.backends import ( # noqa: E402 + JsonBackend, + MemoryBackend, + MemoryCipher, + PostgresBackend, + RedisBackend, + SqliteBackend, +) + +LOGGER = logging.getLogger("smriti.hf_handler") +if not LOGGER.handlers: + logging.basicConfig(level=os.getenv("SMRITI_LOG_LEVEL", "INFO")) + +DEFAULT_CONFIG = { + "project": "Smriti AI", + "base_model": "REPLACE_WITH_BASE_MODEL_ID", + "retrieval_mode": "semantic_graph_identity", + "memory_backend": "json", + "public_demo": False, + "max_memory_entries": 1000, + "enable_identity": True, + "enable_graph": True, + "enable_encryption": True, +} + + +class EndpointHandler: + """Hugging Face custom inference endpoint handler.""" + + def __init__(self, path: str = ""): + self.root = _resolve_root(path) + self.config = _load_config(self.root / "config.json") + self.lock = RLock() + self.memories: Dict[str, MemPalaceLite] = {} + self.identities: Dict[str, IdentityFingerprint] = {} + self.backend_warning: Optional[str] = None + + self.base_model_id = _clean_model_id( + os.getenv("BASE_MODEL_ID") or self.config.get("base_model", "") + ) + self.endpoint_url = os.getenv("HF_ENDPOINT_URL", "").strip() + self.hf_token = os.getenv("HF_TOKEN", "").strip() + self.default_retrieval_mode = os.getenv( + "SMRITI_RETRIEVAL_MODE", + str(self.config.get("retrieval_mode", "semantic_graph_identity")), + ) + self.max_memory_entries = _int_env( + "SMRITI_MAX_MEMORY_ENTRIES", + int(self.config.get("max_memory_entries", 1000)), + ) + self.public_demo = _bool_env("SMRITI_PUBLIC_DEMO", bool(self.config.get("public_demo", False))) + self.enable_graph_default = bool(self.config.get("enable_graph", True)) + self.enable_identity_default = bool(self.config.get("enable_identity", True)) + self.enable_encryption = bool(self.config.get("enable_encryption", True)) + + self.backend, self.backend_name = self._init_backend() + self.model = None + self.tokenizer = None + self.device = "cpu" + if self.endpoint_url: + LOGGER.info( + "Smriti AI handler using remote model endpoint; backend=%s retrieval=%s", + self.backend_name, + self.default_retrieval_mode, + ) + elif self.base_model_id: + self._load_local_model(self.base_model_id) + else: + LOGGER.warning( + "No BASE_MODEL_ID or HF_ENDPOINT_URL configured; handler will run memory-only." + ) + + LOGGER.info( + "Smriti AI handler ready: base_model=%s remote_endpoint=%s backend=%s retrieval=%s encryption=%s public_demo=%s", + self.base_model_id or "memory-only", + bool(self.endpoint_url), + self.backend_name, + self.default_retrieval_mode, + self.enable_encryption and bool(os.getenv("SMRITI_ENCRYPTION_KEY") or os.getenv("SMRITI_MEMORY_KEY")), + self.public_demo, + ) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + start = time.perf_counter() + try: + inputs, parameters = _normalize_request(data) + operation = str(inputs.get("operation", "chat")).lower() + + if operation == "health": + return self._health(start) + if operation == "delete_memory": + return self._delete_memory(inputs, start) + if operation != "chat": + return _error(f"Unsupported operation: {operation}", start) + + return self._chat(inputs, parameters, start) + except Exception as exc: # Defensive boundary for endpoint runtimes. + LOGGER.exception("Unhandled Smriti AI handler error") + return _error(f"handler_error:{exc.__class__.__name__}: {exc}", start) + + # ------------------------------------------------------------------ + # Operation handlers + # ------------------------------------------------------------------ + + def _chat( + self, + inputs: Dict[str, Any], + parameters: Dict[str, Any], + start: float, + ) -> Dict[str, Any]: + user_id = str(inputs.get("user_id") or "").strip() + message = str(inputs.get("message") or "").strip() + topic_id = str(inputs.get("topic_id") or "general").strip() or "general" + if not user_id: + return _error("user_id is required", start) + if not message: + return _error("message is required for chat operation", start) + + retrieval_mode = str(inputs.get("retrieval_mode") or self.default_retrieval_mode) + base_retrieval = _base_retrieval_mode(retrieval_mode) + include_graph = self.enable_graph_default and "graph" in retrieval_mode + identity_enabled = self.enable_identity_default and "identity" in retrieval_mode + + with self.lock: + memory = self._get_memory(user_id, topic_id, base_retrieval) + context, retrieved_memories, graph_facts, retrieval_warning = self._retrieve_context( + memory, + user_id, + topic_id, + message, + include_graph, + ) + identity = self._get_identity(user_id, identity_enabled) + agent = SmritiAILite( + model=self.model, + tokenizer=self.tokenizer, + retrieval_mode=base_retrieval, + session_id=user_id, + topic_id=topic_id, + memory=memory, + identity=identity, + auto_device=False, + ) + agent.build_prompt = lambda user_input: _build_prompt( + agent, + memory, + user_id, + topic_id, + user_input, + include_graph, + identity_enabled, + ) + + generation_calls = 0 + + def generate(prompt: str, max_tokens: int = 256) -> str: + nonlocal generation_calls + generation_calls += 1 + return self._generate_text(prompt, parameters, max_tokens=max_tokens) + + agent._generate = generate # type: ignore[method-assign] + try: + response = agent.chat(message) + except Exception as exc: + LOGGER.exception("Model generation failed") + return _error(f"model_generation_failed:{exc.__class__.__name__}: {exc}", start) + + response = _stabilize_recall_answer(message, response, retrieved_memories, graph_facts) + _replace_last_assistant_history(memory, response) + + identity_check = agent.identity.evaluate_output(response) if identity_enabled else None + save_warning = self._save_memory(user_id, memory) + + warnings = [item for item in [self.backend_warning, retrieval_warning, save_warning] if item] + return { + "response": response, + "retrieved_memories": retrieved_memories, + "graph_facts": graph_facts, + "identity": { + "enabled": identity_enabled, + "drift_score": float(identity_check.distance) if identity_check else 0.0, + "refinement_triggered": generation_calls > 1, + }, + "latency_ms": round((time.perf_counter() - start) * 1000, 3), + "backend": self.backend_name, + "retrieval_mode": retrieval_mode, + "warnings": warnings, + } + + def _delete_memory(self, inputs: Dict[str, Any], start: float) -> Dict[str, Any]: + user_id = str(inputs.get("user_id") or "").strip() + if not user_id: + return _error("user_id is required for delete_memory operation", start) + with self.lock: + existed_cache = self.memories.pop(user_id, None) is not None + self.identities.pop(user_id, None) + try: + deleted_backend = self.backend.delete_user(user_id) + except Exception as exc: + LOGGER.exception("Memory backend delete failed") + return _error(f"backend_delete_failed:{exc.__class__.__name__}: {exc}", start) + return { + "deleted": bool(existed_cache or deleted_backend), + "user_id": user_id, + "latency_ms": round((time.perf_counter() - start) * 1000, 3), + "backend": self.backend_name, + } + + def _health(self, start: float) -> Dict[str, Any]: + return { + "status": "ok", + "project": "Smriti AI", + "base_model": self.base_model_id or ("remote-endpoint" if self.endpoint_url else "memory-only"), + "backend": self.backend_name, + "retrieval_mode": self.default_retrieval_mode, + "latency_ms": round((time.perf_counter() - start) * 1000, 3), + } + + # ------------------------------------------------------------------ + # Runtime setup + # ------------------------------------------------------------------ + + def _init_backend(self) -> Tuple[MemoryBackend, str]: + encryption_key = os.getenv("SMRITI_ENCRYPTION_KEY") or os.getenv("SMRITI_MEMORY_KEY") + if encryption_key: + os.environ["SMRITI_MEMORY_KEY"] = encryption_key + cipher = MemoryCipher(encryption_key if self.enable_encryption else None) + + redis_url = os.getenv("REDIS_URL") or os.getenv("SMRITI_REDIS_URL") + postgres_dsn = os.getenv("POSTGRES_DSN") or os.getenv("SMRITI_POSTGRES_DSN") + selected = (os.getenv("SMRITI_MEMORY_BACKEND") or self.config.get("memory_backend") or "json").lower() + memory_path = os.getenv("SMRITI_MEMORY_PATH", "/tmp/smriti_hf_memory") + + if redis_url: + return RedisBackend(url=redis_url, cipher=cipher), "redis" + if postgres_dsn: + return PostgresBackend(dsn=postgres_dsn, cipher=cipher), "postgres" + if selected == "redis": + return RedisBackend(url=redis_url or "redis://localhost:6379/0", cipher=cipher), "redis" + if selected in {"postgres", "postgresql"}: + return PostgresBackend(dsn=postgres_dsn or "", cipher=cipher), "postgres" + if selected == "sqlite": + return SqliteBackend(path=memory_path, cipher=cipher), "sqlite" + return JsonBackend(root=_json_root(memory_path), cipher=cipher), "json" + + def _load_local_model(self, model_id: str) -> None: + try: + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + except Exception as exc: + raise RuntimeError("Install torch and transformers to load a local base model.") from exc + + self.device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.float32 + if self.device == "cuda": + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + + kwargs = {"token": self.hf_token} if self.hf_token else {} + self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) + if getattr(self.tokenizer, "pad_token_id", None) is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + try: + self.model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, **kwargs) + except TypeError: + self.model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **kwargs) + self.model.to(self.device) + self.model.eval() + LOGGER.info("Loaded local base model %s on %s", model_id, self.device) + + # ------------------------------------------------------------------ + # Memory and generation helpers + # ------------------------------------------------------------------ + + def _get_memory(self, user_id: str, topic_id: str, retrieval_mode: str) -> MemPalaceLite: + self.backend_warning = None + if user_id not in self.memories: + state = None + try: + state = self.backend.load(user_id) + except Exception as exc: + LOGGER.exception("Memory backend load failed; starting empty memory") + self.backend_warning = f"backend_load_failed:{exc.__class__.__name__}" + if state: + memory = MemPalaceLite.from_dict( + state, + retrieval_mode=retrieval_mode, + session_id=user_id, + topic_id=topic_id, + max_facts=self.max_memory_entries, + max_entries_per_topic=self.max_memory_entries, + ) + else: + memory = MemPalaceLite( + retrieval_mode=retrieval_mode, + session_id=user_id, + topic_id=topic_id, + max_facts=self.max_memory_entries, + max_entries_per_topic=self.max_memory_entries, + ) + self.memories[user_id] = memory + memory = self.memories[user_id] + if memory.retrieval_mode != retrieval_mode: + memory = MemPalaceLite.from_dict( + memory.to_dict(), + retrieval_mode=retrieval_mode, + session_id=user_id, + topic_id=topic_id, + max_facts=self.max_memory_entries, + max_entries_per_topic=self.max_memory_entries, + ) + self.memories[user_id] = memory + memory.session_id = user_id + memory.topic_id = topic_id + return memory + + def _get_identity(self, user_id: str, enabled: bool) -> IdentityFingerprint: + if user_id not in self.identities: + threshold = 0.35 if enabled else 2.0 + self.identities[user_id] = IdentityFingerprint( + role="helpful AI assistant with persistent memory", + threshold=threshold, + ) + identity = self.identities[user_id] + if not enabled: + identity.threshold = 2.0 + return identity + + def _retrieve_context( + self, + memory: MemPalaceLite, + user_id: str, + topic_id: str, + message: str, + include_graph: bool, + ) -> Tuple[str, List[str], List[str], Optional[str]]: + try: + context = memory.get_context( + query=message, + session_id=user_id, + topic_id=topic_id, + include_graph=include_graph, + ) + retrieved_memories = memory.retrieve_facts( + message, + k=5, + session_id=user_id, + topic_id=topic_id, + ) + graph_facts = _section_bullets(context, "[RELATED GRAPH FACTS]") if include_graph else [] + return context, retrieved_memories, graph_facts, None + except Exception as exc: + LOGGER.exception("Memory retrieval failed") + return "", [], [], f"retrieval_failed:{exc.__class__.__name__}" + + def _save_memory(self, user_id: str, memory: MemPalaceLite) -> Optional[str]: + try: + self.backend.save(user_id, memory.to_dict()) + return None + except Exception as exc: + LOGGER.exception("Memory backend save failed") + return f"backend_save_failed:{exc.__class__.__name__}" + + def _generate_text(self, prompt: str, parameters: Dict[str, Any], max_tokens: int = 256) -> str: + max_new_tokens = int(parameters.get("max_new_tokens", max_tokens) or max_tokens) + temperature = float(parameters.get("temperature", 0.7)) + top_p = float(parameters.get("top_p", 0.9)) + if self.endpoint_url: + return self._generate_remote(prompt, max_new_tokens, temperature, top_p) + if self.model is not None and self.tokenizer is not None: + return self._generate_local(prompt, max_new_tokens, temperature, top_p) + return _memory_only_answer(prompt) + + def _generate_local( + self, + prompt: str, + max_new_tokens: int, + temperature: float, + top_p: float, + ) -> str: + import torch + + messages = [{"role": "user", "content": prompt}] + try: + formatted = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + except Exception: + formatted = prompt + inputs = self.tokenizer( + formatted, + return_tensors="pt", + truncation=True, + max_length=2048, + ) + inputs = {key: value.to(self.device) for key, value in inputs.items()} + generate_kwargs = { + "max_new_tokens": max_new_tokens, + "do_sample": temperature > 0, + "pad_token_id": getattr(self.tokenizer, "eos_token_id", None), + } + if temperature > 0: + generate_kwargs["temperature"] = temperature + generate_kwargs["top_p"] = top_p + with torch.inference_mode(): + output = self.model.generate(**inputs, **generate_kwargs) + return self.tokenizer.decode( + output[0, inputs["input_ids"].shape[1] :].detach().cpu(), + skip_special_tokens=True, + ).strip() + + def _generate_remote( + self, + prompt: str, + max_new_tokens: int, + temperature: float, + top_p: float, + ) -> str: + payload = { + "inputs": prompt, + "parameters": { + "max_new_tokens": max_new_tokens, + "temperature": temperature, + "top_p": top_p, + }, + } + headers = {"Content-Type": "application/json"} + if self.hf_token: + headers["Authorization"] = f"Bearer {self.hf_token}" + request = urllib.request.Request( + self.endpoint_url, + data=json.dumps(payload).encode("utf-8"), + headers=headers, + method="POST", + ) + try: + with urllib.request.urlopen(request, timeout=120) as response: # noqa: S310 + raw = response.read().decode("utf-8") + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8", errors="replace") + raise RuntimeError(f"remote endpoint HTTP {exc.code}: {body[:300]}") from exc + parsed = json.loads(raw) + return _extract_generated_text(parsed) + + +# ---------------------------------------------------------------------- +# Request, context, and formatting helpers +# ---------------------------------------------------------------------- + + +def _resolve_root(path: str) -> Path: + if path: + root = Path(path).resolve() + return root.parent if root.is_file() else root + return Path(__file__).resolve().parent + + +def _load_config(path: Path) -> Dict[str, Any]: + if not path.exists(): + return dict(DEFAULT_CONFIG) + data = json.loads(path.read_text(encoding="utf-8")) + config = dict(DEFAULT_CONFIG) + config.update(data) + return config + + +def _normalize_request(data: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + if not isinstance(data, dict): + raise ValueError("Request body must be a JSON object.") + if "inputs" in data: + inputs = data.get("inputs") or {} + if isinstance(inputs, str): + inputs = {"message": inputs} + parameters = data.get("parameters") or {} + else: + inputs = data + parameters = data.get("parameters") or {} + if not isinstance(inputs, dict) or not isinstance(parameters, dict): + raise ValueError("inputs and parameters must be JSON objects.") + return inputs, parameters + + +def _base_retrieval_mode(mode: str) -> str: + return "tfidf" if str(mode).lower().startswith("tfidf") else "semantic" + + +def _build_prompt( + agent: SmritiAILite, + memory: MemPalaceLite, + user_id: str, + topic_id: str, + user_input: str, + include_graph: bool, + identity_enabled: bool, +) -> str: + identity = agent.identity.get_identity_prompt() if identity_enabled else "" + context = memory.get_context( + query=user_input, + session_id=user_id, + topic_id=topic_id, + include_graph=include_graph, + ) + parts = [part for part in [identity.strip(), context.strip(), user_input.strip()] if part] + return "\n\n".join(parts) + + +def _section_bullets(context: str, heading: str) -> List[str]: + if heading not in context: + return [] + after = context.split(heading, 1)[1] + chunks = re.split(r"\n\[[A-Z ]+\]", after, maxsplit=1) + section = chunks[0] + bullets = [] + for line in section.splitlines(): + cleaned = line.strip() + if cleaned.startswith("*"): + bullets.append(cleaned.lstrip("* ").strip()) + return bullets + + +def _memory_only_answer(prompt: str) -> str: + facts = _section_bullets(prompt, "[REMEMBERED FACTS]") + graph = _section_bullets(prompt, "[RELATED GRAPH FACTS]") + combined = facts + [item for item in graph if item not in facts] + if combined: + return "I remember: " + "; ".join(combined[:5]) + return "Memory updated. No prior relevant context was found." + + +def _is_recall_query(message: str) -> bool: + lowered = message.lower() + return any( + phrase in lowered + for phrase in [ + "remember", + "what do you know about me", + "who am i", + "where do i work", + "what is my name", + "what do i do", + ] + ) + + +def _stabilize_recall_answer( + message: str, + response: str, + retrieved_memories: List[str], + graph_facts: List[str], +) -> str: + if not _is_recall_query(message): + return response + combined = retrieved_memories + [item for item in graph_facts if item not in retrieved_memories] + if not combined: + return response + if _mentions_memory_terms(response, combined): + return response + return "I remember: " + "; ".join(combined[:5]) + + +def _mentions_memory_terms(response: str, memories: List[str]) -> bool: + response_terms = set(re.findall(r"[a-z0-9']{4,}", response.lower())) + memory_terms = set() + for memory in memories: + memory_terms.update(re.findall(r"[a-z0-9']{4,}", memory.lower())) + return bool(response_terms & memory_terms) + + +def _replace_last_assistant_history(memory: MemPalaceLite, response: str) -> None: + if memory.history and memory.history[-1].category == "assistant_output": + memory.history[-1].content = "Assistant: " + response[:200] + + +def _extract_generated_text(parsed: Any) -> str: + if isinstance(parsed, list) and parsed: + return _extract_generated_text(parsed[0]) + if isinstance(parsed, dict): + for key in ["generated_text", "response", "text", "output"]: + value = parsed.get(key) + if isinstance(value, str): + return value.strip() + if "outputs" in parsed: + return _extract_generated_text(parsed["outputs"]) + if isinstance(parsed, str): + return parsed.strip() + raise RuntimeError("Remote endpoint did not return generated text.") + + +def _json_root(memory_path: str) -> Path: + path = Path(memory_path) + if path.suffix.lower() in {".json", ".jsonl"}: + return path.with_suffix("") + return path + + +def _clean_model_id(value: str) -> str: + value = (value or "").strip() + if not value or value == "REPLACE_WITH_BASE_MODEL_ID": + return "" + return value + + +def _bool_env(name: str, default: bool) -> bool: + raw = os.getenv(name) + if raw is None: + return default + return raw.strip().lower() in {"1", "true", "yes", "on"} + + +def _int_env(name: str, default: int) -> int: + try: + return int(os.getenv(name, str(default))) + except ValueError: + return default + + +def _error(message: str, start: float) -> Dict[str, Any]: + return { + "error": message, + "latency_ms": round((time.perf_counter() - start) * 1000, 3), + } diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..26fd7aad7369baa7b932fd4857819c41473463b3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +# Smriti AI is not yet assumed to be published on PyPI for this deployment artifact. +# Until it is published, install the package directly from the GitHub repository. +git+https://github.com/Luciferai04/smriti-ai.git + +# After PyPI publication, replace the GitHub line above with: +# smriti-ai>=0.3.1 + +transformers +accelerate +torch +sentence-transformers +faiss-cpu +networkx +cryptography +pydantic +redis +psycopg2-binary +huggingface_hub +requests diff --git a/smriti_endpoint_config.yaml b/smriti_endpoint_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ced8b831f4a55571a29a8e6d3f60076b010d10ae --- /dev/null +++ b/smriti_endpoint_config.yaml @@ -0,0 +1,35 @@ +# Smriti AI Hugging Face Inference Endpoint configuration template. +# Values here are documentation defaults. Set real values as endpoint environment +# variables or managed secrets, not as committed plaintext. + +BASE_MODEL_ID: "" +HF_ENDPOINT_URL: "" +HF_TOKEN: "" +SMRITI_MEMORY_BACKEND: "json" +SMRITI_MEMORY_PATH: "/data/smriti_memory" +REDIS_URL: "" +POSTGRES_DSN: "" +SMRITI_ENCRYPTION_KEY: "" +SMRITI_RETRIEVAL_MODE: "semantic_graph_identity" +SMRITI_PUBLIC_DEMO: "false" +SMRITI_MAX_MEMORY_ENTRIES: "1000" + +warnings: + - Do not commit HF_TOKEN. + - Do not commit SMRITI_ENCRYPTION_KEY. + - Production memory should use Redis/Postgres or another external durable storage service. + - The Hugging Face model repository should not contain user memory files. + - Public demo endpoints should not receive real PII. + +variables: + BASE_MODEL_ID: Hugging Face model ID to load locally inside the endpoint. + HF_ENDPOINT_URL: Optional remote model endpoint URL. If set, Smriti calls it instead of loading BASE_MODEL_ID locally. + HF_TOKEN: Hugging Face token for gated/private base models or protected remote endpoints. + SMRITI_MEMORY_BACKEND: json | sqlite | redis | postgres. + SMRITI_MEMORY_PATH: Path for JSON user-memory directory or SQLite database file. + REDIS_URL: External Redis URL. Takes precedence when present. + POSTGRES_DSN: External Postgres DSN. Takes precedence when present and REDIS_URL is empty. + SMRITI_ENCRYPTION_KEY: Encryption key for user memory. Maps to Smriti's SMRITI_MEMORY_KEY. + SMRITI_RETRIEVAL_MODE: tfidf | semantic | semantic_graph | semantic_graph_identity. + SMRITI_PUBLIC_DEMO: true | false. Use true only for non-PII demos. + SMRITI_MAX_MEMORY_ENTRIES: Maximum fact entries retained per user/topic. diff --git a/smriti_vendor/mempalace/__init__.py b/smriti_vendor/mempalace/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f439815c1068f8952def5b094911b31b25bb28d5 --- /dev/null +++ b/smriti_vendor/mempalace/__init__.py @@ -0,0 +1,3 @@ +"""Backward-compatible imports for the renamed :mod:`smriti` package.""" + +from smriti import * # noqa: F401,F403 diff --git a/smriti_vendor/mempalace/__pycache__/__init__.cpython-310.pyc b/smriti_vendor/mempalace/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f85ce92a777f5adee5e723aae1332f72ed6e637 Binary files /dev/null and b/smriti_vendor/mempalace/__pycache__/__init__.cpython-310.pyc differ diff --git a/smriti_vendor/mempalace/__pycache__/agent.cpython-310.pyc b/smriti_vendor/mempalace/__pycache__/agent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ebc3b2b8d432dfe0240ef03edc9f825c3b2c114 Binary files /dev/null and b/smriti_vendor/mempalace/__pycache__/agent.cpython-310.pyc differ diff --git a/smriti_vendor/mempalace/__pycache__/api.cpython-310.pyc b/smriti_vendor/mempalace/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69c79106cb4cc599b9500ec59e584f5137ef702c Binary files /dev/null and b/smriti_vendor/mempalace/__pycache__/api.cpython-310.pyc differ diff --git a/smriti_vendor/mempalace/__pycache__/cli.cpython-310.pyc b/smriti_vendor/mempalace/__pycache__/cli.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b27eb3c4e6c7f3222049ed2c4f72a4f68565903 Binary files /dev/null and b/smriti_vendor/mempalace/__pycache__/cli.cpython-310.pyc differ diff --git a/smriti_vendor/mempalace/__pycache__/core.cpython-310.pyc b/smriti_vendor/mempalace/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..503ba276f1008f0af6ca61fecb62d25af5a7be71 Binary files /dev/null and b/smriti_vendor/mempalace/__pycache__/core.cpython-310.pyc differ diff --git a/smriti_vendor/mempalace/__pycache__/gifp.cpython-310.pyc b/smriti_vendor/mempalace/__pycache__/gifp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c66368889f621f9ded286b2e6b92a9e027ab333 Binary files /dev/null and b/smriti_vendor/mempalace/__pycache__/gifp.cpython-310.pyc differ diff --git a/smriti_vendor/mempalace/__pycache__/identity_fingerprint.cpython-310.pyc b/smriti_vendor/mempalace/__pycache__/identity_fingerprint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..471178c996a83c421134dbf041b80b4bc959296c Binary files /dev/null and b/smriti_vendor/mempalace/__pycache__/identity_fingerprint.cpython-310.pyc differ diff --git a/smriti_vendor/mempalace/__pycache__/knowledge_graph.cpython-310.pyc b/smriti_vendor/mempalace/__pycache__/knowledge_graph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6763c01c3185a486dffa2dba1e2d32b54ff14f7 Binary files /dev/null and b/smriti_vendor/mempalace/__pycache__/knowledge_graph.cpython-310.pyc differ diff --git a/smriti_vendor/mempalace/__pycache__/macp.cpython-310.pyc b/smriti_vendor/mempalace/__pycache__/macp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a12a88d73225fc811671be2f2a0132accea8db1f Binary files /dev/null and b/smriti_vendor/mempalace/__pycache__/macp.cpython-310.pyc differ diff --git a/smriti_vendor/mempalace/__pycache__/mem_palace.cpython-310.pyc b/smriti_vendor/mempalace/__pycache__/mem_palace.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf3b9f999fd971d14dd35e28209a544ee3e9ba66 Binary files /dev/null and b/smriti_vendor/mempalace/__pycache__/mem_palace.cpython-310.pyc differ diff --git a/smriti_vendor/mempalace/__pycache__/semantic_memory.cpython-310.pyc b/smriti_vendor/mempalace/__pycache__/semantic_memory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac64020c3498295c917e079e4b1c088fc588ae6e Binary files /dev/null and b/smriti_vendor/mempalace/__pycache__/semantic_memory.cpython-310.pyc differ diff --git a/smriti_vendor/mempalace/agent.py b/smriti_vendor/mempalace/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..0958d39fdad3b4ee4e62068e7574bb6e0874de7d --- /dev/null +++ b/smriti_vendor/mempalace/agent.py @@ -0,0 +1,3 @@ +"""Compatibility wrapper for :mod:`smriti.agent`.""" + +from smriti.agent import * # noqa: F401,F403 diff --git a/smriti_vendor/mempalace/api.py b/smriti_vendor/mempalace/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a5496bf890d26e0ec649cc7b047a977eb29d144c --- /dev/null +++ b/smriti_vendor/mempalace/api.py @@ -0,0 +1,3 @@ +"""Compatibility wrapper for :mod:`smriti.api`.""" + +from smriti.api import * # noqa: F401,F403 diff --git a/smriti_vendor/mempalace/cli.py b/smriti_vendor/mempalace/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..0a36c8d38b78d9ff9754461b05355f212814fc67 --- /dev/null +++ b/smriti_vendor/mempalace/cli.py @@ -0,0 +1,3 @@ +"""Compatibility wrapper for :mod:`smriti.cli`.""" + +from smriti.cli import * # noqa: F401,F403 diff --git a/smriti_vendor/mempalace/core.py b/smriti_vendor/mempalace/core.py new file mode 100644 index 0000000000000000000000000000000000000000..d588bb141bcb6dd63a23933e734690af5244ab40 --- /dev/null +++ b/smriti_vendor/mempalace/core.py @@ -0,0 +1,3 @@ +"""Compatibility wrapper for :mod:`smriti.core`.""" + +from smriti.core import * # noqa: F401,F403 diff --git a/smriti_vendor/mempalace/gifp.py b/smriti_vendor/mempalace/gifp.py new file mode 100644 index 0000000000000000000000000000000000000000..60019fa0b6e81d80282b7cecf29770c220fb0c76 --- /dev/null +++ b/smriti_vendor/mempalace/gifp.py @@ -0,0 +1,3 @@ +"""Compatibility wrapper for :mod:`smriti.gifp`.""" + +from smriti.gifp import * # noqa: F401,F403 diff --git a/smriti_vendor/mempalace/identity_fingerprint.py b/smriti_vendor/mempalace/identity_fingerprint.py new file mode 100644 index 0000000000000000000000000000000000000000..71ed7236b0872664cafad065aa5ba44105ec8920 --- /dev/null +++ b/smriti_vendor/mempalace/identity_fingerprint.py @@ -0,0 +1,3 @@ +"""Compatibility wrapper for :mod:`smriti.identity_fingerprint`.""" + +from smriti.identity_fingerprint import * # noqa: F401,F403 diff --git a/smriti_vendor/mempalace/knowledge_graph.py b/smriti_vendor/mempalace/knowledge_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4b173040da8352b230947ca33a672ea1c1d240 --- /dev/null +++ b/smriti_vendor/mempalace/knowledge_graph.py @@ -0,0 +1,3 @@ +"""Compatibility wrapper for :mod:`smriti.knowledge_graph`.""" + +from smriti.knowledge_graph import * # noqa: F401,F403 diff --git a/smriti_vendor/mempalace/macp.py b/smriti_vendor/mempalace/macp.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e144ae45b0b242fae8404b758af9ef1e98a7c3 --- /dev/null +++ b/smriti_vendor/mempalace/macp.py @@ -0,0 +1,3 @@ +"""Compatibility wrapper for :mod:`smriti.macp`.""" + +from smriti.macp import * # noqa: F401,F403 diff --git a/smriti_vendor/mempalace/mem_palace.py b/smriti_vendor/mempalace/mem_palace.py new file mode 100644 index 0000000000000000000000000000000000000000..4e2b70c3fe4d41507e0de1867f1cea890dfcb827 --- /dev/null +++ b/smriti_vendor/mempalace/mem_palace.py @@ -0,0 +1,3 @@ +"""Compatibility wrapper for :mod:`smriti.mem_palace`.""" + +from smriti.mem_palace import * # noqa: F401,F403 diff --git a/smriti_vendor/mempalace/semantic_memory.py b/smriti_vendor/mempalace/semantic_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..124dedfef1e696c8231b082615b35747b4b78aa4 --- /dev/null +++ b/smriti_vendor/mempalace/semantic_memory.py @@ -0,0 +1,3 @@ +"""Compatibility wrapper for :mod:`smriti.semantic_memory`.""" + +from smriti.semantic_memory import * # noqa: F401,F403 diff --git a/smriti_vendor/smriti/__init__.py b/smriti_vendor/smriti/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93841731d08ed4ce835a0821890e0d22c15c4bb6 --- /dev/null +++ b/smriti_vendor/smriti/__init__.py @@ -0,0 +1,115 @@ +""" +Smriti AI — Inference-time memory framework for small language models. + +Smriti AI adds semantic memory, reasoning continuity, and identity +governance to any HuggingFace causal LM with zero fine-tuning. The name +comes from smriti, a Sanskrit term associated with memory and remembrance. + +Features: +- Semantic memory with FAISS-based retrieval +- Knowledge graph integration +- Embedding-based identity governance (GIFP v1.0) +- Multi-user support via API and CLI + +Quick start: + from smriti import MemPalaceLite, SmritiAILite + + memory = MemPalaceLite(retrieval_mode="semantic") + agent = SmritiAILite(model=model, tokenizer=tokenizer) + reply = agent.chat("My name is Jordan and I am a marine biologist.") +""" + +from .agent import BaselineGemma, GodelAILite, SmritiAILite +from .backends import ( + JsonBackend, + MemoryBackend, + MemoryCipher, + PostgresBackend, + RedisBackend, + SqliteBackend, + build_backend, +) +from .config import SmritiConfig, configure_environment_from_file, load_config, write_default_config +from .core import MemoryEntry, MemPalaceLite +from .gifp import GIFPLite +from .macp import MACPLite, ReasoningStep + +# New modules for enhanced functionality +try: + from .semantic_memory import ( + RetrievalResult, + SemanticMemory, + MemoryEntry as SemanticMemoryEntry, + ) +except ImportError: + RetrievalResult = None + SemanticMemory = None + SemanticMemoryEntry = None + +try: + from .knowledge_graph import GraphTriple, KnowledgeGraphMemory +except ImportError: + GraphTriple = None + KnowledgeGraphMemory = None + +try: + from .identity_fingerprint import IdentityCheck, IdentityFingerprint +except ImportError: + IdentityCheck = None + IdentityFingerprint = None + +__version__ = "0.3.1" +__author__ = "Alton Lee Wei Bin (creator35lwb)" + +__all__ = [ + "MemoryEntry", + "MemPalaceLite", + "ReasoningStep", + "MACPLite", + "GIFPLite", + "SmritiAILite", + "GodelAILite", + "BaselineGemma", + "MemoryBackend", + "MemoryCipher", + "JsonBackend", + "SqliteBackend", + "RedisBackend", + "PostgresBackend", + "build_backend", + "SmritiConfig", + "load_config", + "configure_environment_from_file", + "write_default_config", +] + +# Add new classes if available +if SemanticMemory is not None: + __all__.extend(["SemanticMemory", "SemanticMemoryEntry", "RetrievalResult"]) +if KnowledgeGraphMemory is not None: + __all__.extend(["KnowledgeGraphMemory", "GraphTriple"]) +if IdentityFingerprint is not None: + __all__.extend(["IdentityFingerprint", "IdentityCheck"]) +__all__.extend(["api_app", "create_app", "get_memory", "set_agent_factory", "set_memory_backend", "cli_main"]) + + +def __getattr__(name: str): + """Lazy optional API/CLI exports without double-registering Prometheus metrics.""" + + if name in {"api_app", "create_app", "get_memory", "set_agent_factory", "set_memory_backend"}: + from .api import app as api_app + from .api import create_app, get_memory, set_agent_factory, set_memory_backend + + values = { + "api_app": api_app, + "create_app": create_app, + "get_memory": get_memory, + "set_agent_factory": set_agent_factory, + "set_memory_backend": set_memory_backend, + } + return values[name] + if name == "cli_main": + from .cli import main as cli_main + + return cli_main + raise AttributeError(name) diff --git a/smriti_vendor/smriti/__main__.py b/smriti_vendor/smriti/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..15cec11be6896cfa68a6ddc77c4b26bc82db1181 --- /dev/null +++ b/smriti_vendor/smriti/__main__.py @@ -0,0 +1,7 @@ +"""Run the Smriti AI CLI with `python -m smriti`.""" + +from .cli import main + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/smriti_vendor/smriti/__pycache__/__init__.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f78ae08a634875f7d30a915f9c04f26681d48335 Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/__init__.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/__main__.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/__main__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de8d0056ef5720898f4140bb12e36897f61fb342 Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/__main__.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/agent.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/agent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1c3608d048710489a67753d9da31690c9f85e0e Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/agent.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/api.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d758a81b7ed82d957796ef4873dac46f91961a91 Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/api.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/backends.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/backends.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f8eb08740fd47326816fd6b98767b2d178f6fb7 Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/backends.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/cli.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/cli.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cffbfaa02f2a64a79a679c6bac530a0c42b87f1d Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/cli.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/config.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d5abf71bc44b405afeb56a20e139b0a9da1b889 Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/config.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/core.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e982d777d1a1972b5b5c2e34e9524624b322fd2 Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/core.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/gifp.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/gifp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4625ccc2885e2d5d2e9892abe1532a0d60de9fa Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/gifp.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/identity_fingerprint.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/identity_fingerprint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7913e851290da61511809e52e9584168610444a Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/identity_fingerprint.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/knowledge_graph.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/knowledge_graph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63f189b28fc7c6e9cac6cc83fb0c41246bdcbf7e Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/knowledge_graph.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/macp.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/macp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9057e87fd09582ca5f5d33189b34d69ae3e6455b Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/macp.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/mem_palace.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/mem_palace.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d921204b99741655e5ceecedd0c713252edf53aa Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/mem_palace.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/__pycache__/semantic_memory.cpython-310.pyc b/smriti_vendor/smriti/__pycache__/semantic_memory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ff288782791437d7b9d528d29d1eaf3115be1c5 Binary files /dev/null and b/smriti_vendor/smriti/__pycache__/semantic_memory.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/agent.py b/smriti_vendor/smriti/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..b344fa26dd5da79a35cdff6ef3b90694622c0d04 --- /dev/null +++ b/smriti_vendor/smriti/agent.py @@ -0,0 +1,262 @@ +import os +from contextlib import nullcontext +from typing import Any, Dict, List, Optional, Tuple + +from .core import MemPalaceLite +from .identity_fingerprint import IdentityFingerprint +from .macp import MACPLite + +try: + import torch +except Exception: + torch = None + +try: + from transformers import GenerationConfig +except Exception: + GenerationConfig = None + + +class SmritiAILite: + """ + Model-agnostic SLM wrapper with semantic memory, graph memory, reasoning + continuity, and GIFP v1.0 identity governance. + + Pass any pre-loaded HuggingFace causal LM and tokenizer. + """ + + def __init__( + self, + model: Any, + tokenizer: Any, + memory_path: Optional[str] = None, + retrieval_mode: str = "semantic", + session_id: str = "default", + topic_id: str = "general", + memory: Optional[MemPalaceLite] = None, + identity: Optional[IdentityFingerprint] = None, + auto_device: bool = True, + ): + self.model = model + self.tokenizer = tokenizer + self.session_id = session_id + self.topic_id = topic_id + + if memory is not None: + self.memory = memory + elif memory_path and os.path.exists(memory_path): + self.memory = MemPalaceLite.load(memory_path, retrieval_mode=retrieval_mode) + else: + self.memory = MemPalaceLite( + retrieval_mode=retrieval_mode, + session_id=session_id, + topic_id=topic_id, + ) + + self.continuity = MACPLite() + self.identity = identity or IdentityFingerprint( + role="helpful AI assistant with persistent memory" + ) + self.identity.set_constraints( + [ + "Always be helpful and accurate", + "Reference previous context when relevant", + "Maintain logical consistency across turns", + "Acknowledge uncertainty when present", + ] + ) + self.device, self.autocast_dtype = configure_inference_device() + if auto_device: + self._move_model_to_best_device() + + def build_prompt(self, user_input: str) -> str: + identity = self.identity.get_identity_prompt() + ctx = self.memory.get_context( + query=user_input, + session_id=self.session_id, + topic_id=self.topic_id, + ) + if ctx: + return identity + "\n" + ctx + "\n\n" + user_input + return identity + "\n" + user_input + + def _generate(self, prompt: str, max_tokens: int = 256) -> str: + if torch is None or GenerationConfig is None: + raise RuntimeError("torch and transformers are required for model generation.") + + messages = [{"role": "user", "content": prompt}] + try: + formatted = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + except Exception: + formatted = prompt + + inputs = self.tokenizer( + formatted, + return_tensors="pt", + truncation=True, + max_length=2048, + ) + model_device = _model_device(self.model) or self.device + inputs = {key: value.to(model_device) for key, value in inputs.items()} + cfg = GenerationConfig( + max_new_tokens=max_tokens, + temperature=0.7, + top_p=0.9, + do_sample=True, + pad_token_id=getattr(self.tokenizer, "eos_token_id", None), + ) + with torch.inference_mode(), _autocast_context(model_device, self.autocast_dtype): + out = self.model.generate(**inputs, generation_config=cfg) + return self.tokenizer.decode( + out[0, inputs["input_ids"].shape[1] :].detach().cpu(), + skip_special_tokens=True, + ).strip() + + def chat(self, user_input: str, refine: bool = False) -> str: + self.continuity.start_chain(user_input) + self.identity.observe_user_input(user_input) + prompt = self.build_prompt(user_input) + response = self._generate(prompt) + + context = self.memory.get_context( + query=user_input, + session_id=self.session_id, + topic_id=self.topic_id, + ) + response, identity_check = self.identity.ensure_aligned( + response, + self._generate, + user_input=user_input, + context=context, + ) + if refine and identity_check.consistency_score < 0.5: + response = self.identity.refinement_pass( + self._generate, + response, + user_input=user_input, + context=context, + ) + identity_check = self.identity.evaluate_output(response) + + self.continuity.add_step( + user_input, + response, + identity_check.consistency_score, + "continue" if identity_check.consistency_score > 0.7 else "refine", + ) + for fact in self.memory.extract_facts(response, user_input=user_input): + self.memory.add_fact(fact, session_id=self.session_id, topic_id=self.topic_id) + self.memory.add_to_history("User: " + user_input, "user_input") + self.memory.add_to_history( + "Assistant: " + response[:200], "assistant_output" + ) + self.identity.record_behavior(response) + return response + + def save_memory(self, path: str): + self.memory.save(path) + + def load_memory(self, path: str): + self.memory = MemPalaceLite.load(path, retrieval_mode=self.memory.retrieval_mode) + + def get_memory_state(self) -> Dict: + return self.memory.to_dict() + + def get_reasoning_chain(self) -> str: + return self.continuity.get_chain_summary() + + def _move_model_to_best_device(self) -> None: + if torch is None or self.device is None or str(self.device) == "cpu": + return + try: + current = _model_device(self.model) + if current is not None and str(current).startswith("cuda"): + return + self.model.to(self.device) + except Exception: + pass + + +class BaselineGemma: + """Plain causal LM with no memory, no identity layer, no continuity.""" + + def __init__(self, model: Any, tokenizer: Any, auto_device: bool = True): + self.model = model + self.tokenizer = tokenizer + self._history: List[str] = [] + self.device, self.autocast_dtype = configure_inference_device() + if auto_device and torch is not None and str(self.device) != "cpu": + try: + self.model.to(self.device) + except Exception: + pass + + def chat(self, user_input: str) -> str: + if torch is None or GenerationConfig is None: + raise RuntimeError("torch and transformers are required for model generation.") + + messages = [{"role": "user", "content": user_input}] + try: + prompt = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + except Exception: + prompt = user_input + inputs = self.tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=2048, + ) + model_device = _model_device(self.model) or self.device + inputs = {key: value.to(model_device) for key, value in inputs.items()} + cfg = GenerationConfig( + max_new_tokens=getattr(self, "max_new_tokens", 256), + temperature=0.7, + top_p=0.9, + do_sample=True, + pad_token_id=getattr(self.tokenizer, "eos_token_id", None), + ) + with torch.inference_mode(), _autocast_context(model_device, self.autocast_dtype): + out = self.model.generate(**inputs, generation_config=cfg) + response = self.tokenizer.decode( + out[0, inputs["input_ids"].shape[1] :].detach().cpu(), + skip_special_tokens=True, + ).strip() + self._history.extend(["User: " + user_input, "Assistant: " + response]) + return response + + def reset(self): + self._history = [] + + +def configure_inference_device() -> Tuple[Any, Any]: + """Return the preferred torch device and mixed-precision dtype.""" + + if torch is None: + return "cpu", None + if torch.cuda.is_available(): + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + return torch.device("cuda"), dtype + return torch.device("cpu"), torch.float32 + + +def _model_device(model: Any) -> Any: + try: + return next(model.parameters()).device + except Exception: + return None + + +def _autocast_context(device: Any, dtype: Any): + if torch is None or dtype is None: + return nullcontext() + if str(device).startswith("cuda"): + return torch.autocast(device_type="cuda", dtype=dtype) + return nullcontext() + + +# Backwards compatibility for existing user code. +GodelAILite = SmritiAILite diff --git a/smriti_vendor/smriti/api.py b/smriti_vendor/smriti/api.py new file mode 100644 index 0000000000000000000000000000000000000000..fde07842c5f984740888a7a5e0de7a5a0560e956 --- /dev/null +++ b/smriti_vendor/smriti/api.py @@ -0,0 +1,538 @@ +"""FastAPI layer for multi-user, multi-agent Smriti AI memory access.""" + +from __future__ import annotations + +import json +import logging +import os +import time +import uuid +from contextlib import contextmanager +from threading import RLock +from typing import Any, Callable, Dict, Iterator, List, Optional + +from fastapi import FastAPI, HTTPException, Request, Response +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, Histogram, generate_latest + +from .backends import MemoryBackend, build_backend +from .config import configure_environment_from_file, load_config +from .core import MemPalaceLite + + +AgentFactory = Callable[..., Any] + +USER_MEMORIES: Dict[str, MemPalaceLite] = {} +MEMORY_LOCK = RLock() +MEMORY_BACKEND: Optional[MemoryBackend] = None +AGENT_FACTORY: Optional[AgentFactory] = None +LOGGER = logging.getLogger("smriti.api") +LOGGER.setLevel(logging.INFO) +if not LOGGER.handlers: + _handler = logging.StreamHandler() + _handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s")) + LOGGER.addHandler(_handler) +LOGGER.propagate = False + +HTTP_REQUESTS = Counter( + "smriti_http_requests_total", + "Total HTTP requests handled by the Smriti AI API.", + ("method", "path", "status"), +) +HTTP_ERRORS = Counter( + "smriti_http_errors_total", + "Total HTTP requests that completed with status >= 500.", + ("method", "path"), +) +HTTP_LATENCY = Histogram( + "smriti_http_request_latency_seconds", + "End-to-end HTTP request latency.", + ("method", "path"), +) +RETRIEVAL_LATENCY = Histogram( + "smriti_retrieval_latency_seconds", + "Memory retrieval latency for chat requests.", + ("retrieval_mode",), +) +TOKEN_USAGE = Counter( + "smriti_tokens_total", + "Approximate whitespace-token count observed by the API.", + ("user_id", "agent_id"), +) +USER_MEMORY_COUNT = Gauge( + "smriti_user_memories", + "Number of in-memory user memory stores.", +) +USER_MEMORY_BYTES = Gauge( + "smriti_user_memory_bytes", + "Approximate serialized memory size by user.", + ("user_id",), +) + + +class ChatRequest(BaseModel): + user_id: str + message: str + topic_id: str = "general" + agent_id: str = "executor" + retrieval_mode: str = "semantic" + + +class ChatResponse(BaseModel): + user_id: str + agent_id: str + topic_id: str + response: str + retrieved_context: str + memory: Dict[str, Any] + + +class MemoryLoadRequest(BaseModel): + user_id: str + memory: Optional[Dict[str, Any]] = None + path: Optional[str] = None + retrieval_mode: str = "semantic" + + +class MemorySaveRequest(BaseModel): + user_id: str + path: Optional[str] = None + + +class MemoryDeleteRequest(BaseModel): + user_id: str + path: Optional[str] = None + + +class GraphQueryRequest(BaseModel): + user_id: str + query_entity: str + topic_id: Optional[str] = None + depth: int = Field(default=1, ge=1, le=4) + + +def set_agent_factory(factory: Optional[AgentFactory]) -> None: + """ + Register a callable that returns a configured model agent. + + The callable receives `user_id`, `memory`, `topic_id`, and `agent_id`. + When no factory is configured, `/chat` runs in memory-only mode. + """ + + global AGENT_FACTORY + AGENT_FACTORY = factory + + +def set_memory_backend(backend: Optional[MemoryBackend]) -> None: + """Override the configured persistence backend for tests or deployments.""" + + global MEMORY_BACKEND + MEMORY_BACKEND = backend + + +def get_memory_backend() -> MemoryBackend: + """Return the configured durable backend, constructing it lazily from env.""" + + global MEMORY_BACKEND + if MEMORY_BACKEND is None: + configure_environment_from_file() + MEMORY_BACKEND = build_backend() + return MEMORY_BACKEND + + +def get_memory(user_id: str, retrieval_mode: str = "semantic") -> MemPalaceLite: + with MEMORY_LOCK: + if user_id not in USER_MEMORIES: + state = None + try: + state = get_memory_backend().load(user_id) + except Exception: + LOGGER.exception("Durable memory load failed; starting empty memory") + if state: + memory = MemPalaceLite.from_dict(state, retrieval_mode=retrieval_mode) + memory.session_id = user_id + else: + memory = MemPalaceLite( + retrieval_mode=retrieval_mode, + session_id=user_id, + ) + USER_MEMORIES[user_id] = memory + USER_MEMORY_COUNT.set(len(USER_MEMORIES)) + return USER_MEMORIES[user_id] + + +def create_app() -> FastAPI: + config = configure_environment_from_file() + app = FastAPI( + title="Smriti AI API", + version="0.3.1", + description="Semantic memory, knowledge graph and identity governance API.", + ) + app.add_middleware( + CORSMiddleware, + allow_origins=config.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.middleware("http") + async def request_observability(request: Request, call_next: Callable[..., Any]) -> Response: + request_id = request.headers.get("x-request-id", str(uuid.uuid4())) + path = request.url.path + start = time.perf_counter() + status_code = 500 + try: + _enforce_api_key(request) + response = await call_next(request) + status_code = response.status_code + except HTTPException as exc: + status_code = exc.status_code + response = Response( + content=f'{{"detail":"{exc.detail}"}}', + status_code=exc.status_code, + media_type="application/json", + ) + except Exception: + LOGGER.exception( + "Unhandled API request failure", + extra={"request_id": request_id, "path": path}, + ) + response = Response( + content='{"detail":"Internal server error"}', + status_code=500, + media_type="application/json", + ) + duration = time.perf_counter() - start + HTTP_LATENCY.labels(request.method, path).observe(duration) + HTTP_REQUESTS.labels(request.method, path, str(status_code)).inc() + if status_code >= 500: + HTTP_ERRORS.labels(request.method, path).inc() + USER_MEMORY_COUNT.set(len(USER_MEMORIES)) + response.headers["x-request-id"] = request_id + LOGGER.info( + "request completed request_id=%s method=%s path=%s status=%s duration_s=%.6f", + request_id, + request.method, + path, + status_code, + duration, + ) + return response + + @app.get("/health") + def health() -> Dict[str, Any]: + return {"status": "ok", "users": len(USER_MEMORIES)} + + @app.get("/metrics") + def metrics() -> Response: + return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST) + + @app.post("/chat", response_model=ChatResponse) + def chat(request: ChatRequest) -> ChatResponse: + memory = get_memory(request.user_id, retrieval_mode=request.retrieval_mode) + memory.retrieval_mode = request.retrieval_mode + memory.topic_id = request.topic_id + context, degraded, warnings = _safe_get_context( + memory, + query=request.message, + session_id=request.user_id, + topic_id=request.topic_id, + retrieval_mode=request.retrieval_mode, + ) + + if AGENT_FACTORY is not None: + agent = _build_agent( + AGENT_FACTORY, + user_id=request.user_id, + memory=memory, + topic_id=request.topic_id, + agent_id=request.agent_id, + ) + with MEMORY_LOCK: + try: + response = agent.chat(request.message) + except Exception as exc: + LOGGER.exception("Agent factory chat failed") + degraded = True + warnings.append(f"agent_failure:{exc.__class__.__name__}") + response = _memory_only_response(context) + state = memory.to_dict() + else: + response = _memory_only_response(context) + with MEMORY_LOCK: + _safe_update_memory( + memory, + request.message, + response, + request.user_id, + request.topic_id, + warnings, + ) + state = memory.to_dict() + _persist_if_configured(request.user_id, state, warnings) + + TOKEN_USAGE.labels(request.user_id, request.agent_id).inc( + _count_tokens(request.message) + _count_tokens(response) + ) + state["_degraded"] = degraded + state["_warnings"] = warnings + return ChatResponse( + user_id=request.user_id, + agent_id=request.agent_id, + topic_id=request.topic_id, + response=response, + retrieved_context=context, + memory=state, + ) + + @app.post("/memory/load") + def load_memory(request: MemoryLoadRequest) -> Dict[str, Any]: + with MEMORY_LOCK: + if request.path: + memory = MemPalaceLite.load( + request.path, + retrieval_mode=request.retrieval_mode, + ) + elif request.memory: + memory = MemPalaceLite.from_dict( + request.memory or {}, + retrieval_mode=request.retrieval_mode, + ) + else: + state = get_memory_backend().load(request.user_id) + if state is None: + raise HTTPException(status_code=404, detail="No memory found for user.") + memory = MemPalaceLite.from_dict( + state, + retrieval_mode=request.retrieval_mode, + ) + memory.session_id = request.user_id + USER_MEMORIES[request.user_id] = memory + return memory.to_dict() + + @app.post("/memory/save") + def save_memory(request: MemorySaveRequest) -> Dict[str, Any]: + memory = get_memory(request.user_id) + with MEMORY_LOCK: + if request.path: + memory.save(request.path) + state = memory.to_dict() + if not request.path: + get_memory_backend().save(request.user_id, state) + _observe_memory_size(request.user_id, state) + return state + + @app.post("/memory/delete") + def delete_memory(request: MemoryDeleteRequest) -> Dict[str, Any]: + with MEMORY_LOCK: + existed = USER_MEMORIES.pop(request.user_id, None) is not None + USER_MEMORY_COUNT.set(len(USER_MEMORIES)) + deleted_file = False + if request.path and os.path.exists(request.path): + os.remove(request.path) + deleted_file = True + deleted_backend = False + try: + deleted_backend = get_memory_backend().delete_user(request.user_id) + except Exception: + LOGGER.exception("Durable memory deletion failed") + try: + USER_MEMORY_BYTES.remove(request.user_id) + except Exception: + pass + return { + "user_id": request.user_id, + "deleted_memory": existed, + "deleted_file": deleted_file, + "deleted_backend": deleted_backend, + "remaining_users": len(USER_MEMORIES), + } + + @app.post("/graph/query") + def graph_query(request: GraphQueryRequest) -> Dict[str, Any]: + memory = get_memory(request.user_id) + try: + triples = memory.knowledge_graph.query_graph( + request.user_id, + request.query_entity, + depth=request.depth, + topic_id=request.topic_id, + ) + degraded = False + warnings: List[str] = [] + except Exception as exc: + LOGGER.exception("Knowledge graph query failed") + triples = [] + degraded = True + warnings = [f"knowledge_graph_failure:{exc.__class__.__name__}"] + return { + "user_id": request.user_id, + "query_entity": request.query_entity, + "triples": [triple.__dict__ for triple in triples], + "facts": memory.knowledge_graph.triples_to_text(triples), + "degraded": degraded, + "warnings": warnings, + } + + return app + + +def _build_agent(factory: AgentFactory, **kwargs: Any) -> Any: + try: + return factory(**kwargs) + except TypeError: + return factory(kwargs["memory"]) + + +def _memory_only_response(context: str) -> str: + if context: + bullets = [] + for line in context.splitlines(): + cleaned = line.strip() + if cleaned.startswith("* "): + bullets.append(cleaned[2:].strip()) + if bullets: + rendered = "\n".join(f"- {fact}" for fact in bullets[:5]) + return f"Memory updated. I found relevant context:\n{rendered}" + return "Memory updated. Relevant context is available for the configured model." + return "Memory updated. No prior relevant context was found." + + +def _safe_get_context( + memory: MemPalaceLite, + query: str, + session_id: str, + topic_id: str, + retrieval_mode: str, +) -> tuple[str, bool, List[str]]: + warnings: List[str] = [] + with _observe_retrieval(retrieval_mode): + try: + return ( + memory.get_context( + query=query, + session_id=session_id, + topic_id=topic_id, + ), + False, + warnings, + ) + except Exception as exc: + LOGGER.exception("Primary retrieval failed; degrading to TF-IDF/no-graph context") + warnings.append(f"primary_retrieval_failure:{exc.__class__.__name__}") + + original_mode = memory.retrieval_mode + try: + memory.retrieval_mode = "tfidf" + with _observe_retrieval("tfidf"): + context = memory.get_context( + query=query, + session_id=session_id, + topic_id=topic_id, + include_graph=False, + ) + warnings.append("degraded_to_tfidf") + return context, True, warnings + except Exception as exc: + LOGGER.exception("Fallback TF-IDF retrieval failed") + warnings.append(f"fallback_retrieval_failure:{exc.__class__.__name__}") + return "", True, warnings + finally: + memory.retrieval_mode = original_mode + + +def _safe_update_memory( + memory: MemPalaceLite, + user_message: str, + response: str, + session_id: str, + topic_id: str, + warnings: List[str], +) -> None: + try: + for fact in memory.extract_facts("", user_input=user_message): + try: + memory.add_fact(fact, session_id=session_id, topic_id=topic_id) + except Exception as exc: + LOGGER.exception("Fact storage failed") + warnings.append(f"fact_storage_failure:{exc.__class__.__name__}") + memory.add_to_history("User: " + user_message, "user_input") + memory.add_to_history("Assistant: " + response[:200], "assistant_output") + except Exception as exc: + LOGGER.exception("Memory update failed") + warnings.append(f"memory_update_failure:{exc.__class__.__name__}") + + +def _persist_if_configured(user_id: str, state: Dict[str, Any], warnings: List[str]) -> None: + if os.getenv("SMRITI_AUTOSAVE", "0").lower() not in {"1", "true", "yes"}: + _observe_memory_size(user_id, state) + return + try: + get_memory_backend().save(user_id, state) + _observe_memory_size(user_id, state) + except Exception as exc: + LOGGER.exception("Durable memory autosave failed") + warnings.append(f"autosave_failure:{exc.__class__.__name__}") + + +def _observe_memory_size(user_id: str, state: Dict[str, Any]) -> None: + try: + USER_MEMORY_BYTES.labels(user_id).set(len(json.dumps(state))) + except Exception: + pass + + +@contextmanager +def _observe_retrieval(retrieval_mode: str) -> Iterator[None]: + start = time.perf_counter() + try: + yield + finally: + RETRIEVAL_LATENCY.labels(retrieval_mode).observe(time.perf_counter() - start) + + +def _count_tokens(text: str) -> int: + return max(1, len(text.split())) if text else 0 + + +def _enforce_api_key(request: Request) -> None: + expected = os.getenv("SMRITI_API_KEY") + if not expected: + return + if request.url.path in {"/health", "/metrics", "/docs", "/openapi.json"}: + return + supplied = request.headers.get("x-api-key") + if supplied != expected: + raise HTTPException(status_code=401, detail="Invalid or missing API key.") + + +app = create_app() + + +def main(argv: Optional[List[str]] = None) -> None: + """Run the API with `python -m smriti.api` or the `smriti-api` entry point.""" + + import argparse + import uvicorn + + parser = argparse.ArgumentParser(description="Run the Smriti AI FastAPI service.") + parser.add_argument("--config", help="Path to config.yaml.") + parser.add_argument("--host", help="Bind host. Defaults to config or SMRITI_HOST.") + parser.add_argument("--port", type=int, help="Bind port. Defaults to config or SMRITI_PORT.") + parser.add_argument("--reload", action="store_true", help="Enable Uvicorn reload mode.") + args = parser.parse_args(argv) + if args.config: + os.environ["SMRITI_CONFIG_PATH"] = args.config + config = load_config() + uvicorn.run( + "smriti.api:app" if args.reload else app, + host=args.host or os.getenv("SMRITI_HOST", config.host), + port=args.port or int(os.getenv("SMRITI_PORT", config.port)), + reload=args.reload, + ) + + +if __name__ == "__main__": + main() diff --git a/smriti_vendor/smriti/backends.py b/smriti_vendor/smriti/backends.py new file mode 100644 index 0000000000000000000000000000000000000000..035669cdf621b427c81ae5550c77561d9ae13257 --- /dev/null +++ b/smriti_vendor/smriti/backends.py @@ -0,0 +1,494 @@ +"""Durable memory backends for Smriti AI. + +Backends persist complete user memory blobs and also expose a minimal entry API +for tools that want to store/retrieve lightweight facts without instantiating the +full runtime. Optional encryption is applied at the blob boundary so JSON, SQL, +Redis, and Postgres stores share the same privacy behavior. +""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import os +import re +import sqlite3 +import time +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Dict, List, Optional + + +class MemoryBackend(ABC): + """Abstract persistence contract for user-isolated Smriti AI memory.""" + + @abstractmethod + def load(self, user_id: str) -> Optional[Dict[str, Any]]: + """Load a complete memory state for one user, or None if absent.""" + + @abstractmethod + def save(self, user_id: str, memory: Dict[str, Any]) -> None: + """Persist a complete memory state for one user.""" + + @abstractmethod + def add_entry( + self, + user_id: str, + session_id: str, + topic_id: str, + text: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Persist one lightweight fact/entry for a user/session/topic.""" + + @abstractmethod + def retrieve( + self, + user_id: str, + session_id: Optional[str] = None, + topic_id: Optional[str] = None, + query: str = "", + k: int = 5, + ) -> List[Dict[str, Any]]: + """Retrieve lightweight entries scoped to a user/session/topic.""" + + @abstractmethod + def delete_user(self, user_id: str) -> bool: + """Delete all memory owned by one user. Return whether anything existed.""" + + +class MemoryCipher: + """Optional symmetric encryption wrapper using Fernet when configured.""" + + def __init__(self, secret: Optional[str] = None): + self.secret = secret or os.getenv("SMRITI_MEMORY_KEY") + self._fernet = None + if self.secret: + try: + from cryptography.fernet import Fernet + except Exception as exc: # pragma: no cover - depends on optional install. + raise RuntimeError( + "SMRITI_MEMORY_KEY is set, but cryptography is not installed. " + "Install smriti-ai[security] or smriti-ai[full]." + ) from exc + self._fernet = Fernet(_fernet_key(self.secret)) + + @property + def enabled(self) -> bool: + return self._fernet is not None + + def wrap(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._fernet: + return {"encrypted": False, "payload": payload} + raw = json.dumps(payload, sort_keys=True).encode("utf-8") + return { + "encrypted": True, + "algorithm": "fernet-sha256-derived-key", + "payload": self._fernet.encrypt(raw).decode("utf-8"), + } + + def unwrap(self, wrapper: Dict[str, Any]) -> Dict[str, Any]: + if not wrapper.get("encrypted"): + return dict(wrapper.get("payload", {})) + if not self._fernet: + raise RuntimeError("Memory blob is encrypted but SMRITI_MEMORY_KEY is not configured.") + decrypted = self._fernet.decrypt(wrapper["payload"].encode("utf-8")) + return json.loads(decrypted.decode("utf-8")) + + +def build_backend(kind: Optional[str] = None, **kwargs: Any) -> MemoryBackend: + """Construct a backend from an explicit kind or SMRITI_MEMORY_BACKEND.""" + + selected = (kind or os.getenv("SMRITI_MEMORY_BACKEND") or "json").lower() + if selected == "json": + return JsonBackend(root=kwargs.get("root") or os.getenv("SMRITI_MEMORY_DIR", "data/memory")) + if selected == "sqlite": + return SqliteBackend(path=kwargs.get("path") or os.getenv("SMRITI_SQLITE_PATH", "data/smriti_memory.sqlite3")) + if selected == "redis": + return RedisBackend(url=kwargs.get("url") or os.getenv("SMRITI_REDIS_URL", "redis://localhost:6379/0")) + if selected in {"postgres", "postgresql"}: + return PostgresBackend(dsn=kwargs.get("dsn") or os.getenv("SMRITI_POSTGRES_DSN", "")) + raise ValueError("SMRITI_MEMORY_BACKEND must be one of: json, sqlite, redis, postgres.") + + +class JsonBackend(MemoryBackend): + """File-per-user JSON backend. This preserves the original local behavior.""" + + def __init__(self, root: str | Path = "data/memory", cipher: Optional[MemoryCipher] = None): + self.root = Path(root) + self.cipher = cipher or MemoryCipher() + + def load(self, user_id: str) -> Optional[Dict[str, Any]]: + path = self._path(user_id) + if not path.exists(): + return None + return self.cipher.unwrap(json.loads(path.read_text(encoding="utf-8"))) + + def save(self, user_id: str, memory: Dict[str, Any]) -> None: + self.root.mkdir(parents=True, exist_ok=True) + self._path(user_id).write_text( + json.dumps(self.cipher.wrap(memory), indent=2), + encoding="utf-8", + ) + + def add_entry( + self, + user_id: str, + session_id: str, + topic_id: str, + text: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + state = self.load(user_id) or {"backend_entries": []} + state.setdefault("backend_entries", []).append(_entry(session_id, topic_id, text, metadata)) + self.save(user_id, state) + + def retrieve( + self, + user_id: str, + session_id: Optional[str] = None, + topic_id: Optional[str] = None, + query: str = "", + k: int = 5, + ) -> List[Dict[str, Any]]: + state = self.load(user_id) or {} + return _rank_entries(state.get("backend_entries", []), session_id, topic_id, query, k) + + def delete_user(self, user_id: str) -> bool: + path = self._path(user_id) + existed = path.exists() + if existed: + path.unlink() + return existed + + def _path(self, user_id: str) -> Path: + return self.root / f"{_safe_id(user_id)}.json" + + +class SqliteBackend(MemoryBackend): + """SQLite backend for local durable multi-user memory.""" + + def __init__(self, path: str | Path = "data/smriti_memory.sqlite3", cipher: Optional[MemoryCipher] = None): + self.path = Path(path) + self.cipher = cipher or MemoryCipher() + self._init_schema() + + def load(self, user_id: str) -> Optional[Dict[str, Any]]: + with self._connect() as conn: + row = conn.execute("SELECT payload FROM user_memory WHERE user_id = ?", (user_id,)).fetchone() + if not row: + return None + return self.cipher.unwrap(json.loads(row[0])) + + def save(self, user_id: str, memory: Dict[str, Any]) -> None: + payload = json.dumps(self.cipher.wrap(memory)) + with self._connect() as conn: + conn.execute( + """ + INSERT INTO user_memory(user_id, payload, updated_at) + VALUES(?, ?, ?) + ON CONFLICT(user_id) DO UPDATE SET payload=excluded.payload, updated_at=excluded.updated_at + """, + (user_id, payload, time.time()), + ) + + def add_entry( + self, + user_id: str, + session_id: str, + topic_id: str, + text: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + with self._connect() as conn: + conn.execute( + """ + INSERT INTO memory_entries(user_id, session_id, topic_id, text, metadata, created_at) + VALUES(?, ?, ?, ?, ?, ?) + """, + (user_id, session_id, topic_id, text, json.dumps(metadata or {}), time.time()), + ) + + def retrieve( + self, + user_id: str, + session_id: Optional[str] = None, + topic_id: Optional[str] = None, + query: str = "", + k: int = 5, + ) -> List[Dict[str, Any]]: + clauses = ["user_id = ?"] + params: List[Any] = [user_id] + if session_id: + clauses.append("session_id = ?") + params.append(session_id) + if topic_id: + clauses.append("topic_id = ?") + params.append(topic_id) + params.append(max(1, k * 5)) + sql = f""" + SELECT session_id, topic_id, text, metadata, created_at + FROM memory_entries + WHERE {' AND '.join(clauses)} + ORDER BY created_at DESC + LIMIT ? + """ + with self._connect() as conn: + rows = conn.execute(sql, params).fetchall() + entries = [ + { + "session_id": row[0], + "topic_id": row[1], + "text": row[2], + "metadata": json.loads(row[3] or "{}"), + "created_at": row[4], + } + for row in rows + ] + return _rank_entries(entries, session_id, topic_id, query, k) + + def delete_user(self, user_id: str) -> bool: + with self._connect() as conn: + before = conn.total_changes + conn.execute("DELETE FROM user_memory WHERE user_id = ?", (user_id,)) + conn.execute("DELETE FROM memory_entries WHERE user_id = ?", (user_id,)) + return conn.total_changes > before + + def _connect(self) -> sqlite3.Connection: + self.path.parent.mkdir(parents=True, exist_ok=True) + return sqlite3.connect(self.path) + + def _init_schema(self) -> None: + with self._connect() as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS user_memory( + user_id TEXT PRIMARY KEY, + payload TEXT NOT NULL, + updated_at REAL NOT NULL + ) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_entries( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + session_id TEXT NOT NULL, + topic_id TEXT NOT NULL, + text TEXT NOT NULL, + metadata TEXT NOT NULL, + created_at REAL NOT NULL + ) + """ + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_entries_user_session_topic ON memory_entries(user_id, session_id, topic_id, created_at)" + ) + + +class RedisBackend(MemoryBackend): # pragma: no cover - requires external Redis service. + """Redis backend using string payloads and per-user entry lists.""" + + def __init__(self, url: str = "redis://localhost:6379/0", cipher: Optional[MemoryCipher] = None): + try: + import redis + except Exception as exc: # pragma: no cover - optional dependency. + raise RuntimeError("Install redis to use RedisBackend: pip install smriti-ai[backends]") from exc + self.client = redis.Redis.from_url(url, decode_responses=True) + self.cipher = cipher or MemoryCipher() + + def load(self, user_id: str) -> Optional[Dict[str, Any]]: + raw = self.client.get(self._payload_key(user_id)) + if not raw: + return None + return self.cipher.unwrap(json.loads(raw)) + + def save(self, user_id: str, memory: Dict[str, Any]) -> None: + self.client.set(self._payload_key(user_id), json.dumps(self.cipher.wrap(memory))) + + def add_entry(self, user_id: str, session_id: str, topic_id: str, text: str, metadata: Optional[Dict[str, Any]] = None) -> None: + self.client.lpush(self._entries_key(user_id), json.dumps(_entry(session_id, topic_id, text, metadata))) + + def retrieve(self, user_id: str, session_id: Optional[str] = None, topic_id: Optional[str] = None, query: str = "", k: int = 5) -> List[Dict[str, Any]]: + raw_entries = self.client.lrange(self._entries_key(user_id), 0, max(0, k * 5 - 1)) + entries = [json.loads(item) for item in raw_entries] + return _rank_entries(entries, session_id, topic_id, query, k) + + def delete_user(self, user_id: str) -> bool: + return bool(self.client.delete(self._payload_key(user_id), self._entries_key(user_id))) + + def _payload_key(self, user_id: str) -> str: + return f"smriti:user:{_safe_id(user_id)}:payload" + + def _entries_key(self, user_id: str) -> str: + return f"smriti:user:{_safe_id(user_id)}:entries" + + +class PostgresBackend(MemoryBackend): # pragma: no cover - requires external Postgres service. + """Postgres backend using psycopg2 and indexed user/session/topic tables.""" + + def __init__(self, dsn: str, cipher: Optional[MemoryCipher] = None): + if not dsn: + raise ValueError("SMRITI_POSTGRES_DSN is required for PostgresBackend.") + try: + import psycopg2 + except Exception as exc: # pragma: no cover - optional dependency. + raise RuntimeError("Install psycopg2-binary to use PostgresBackend: pip install smriti-ai[backends]") from exc + self._psycopg2 = psycopg2 + self.dsn = dsn + self.cipher = cipher or MemoryCipher() + self._init_schema() + + def load(self, user_id: str) -> Optional[Dict[str, Any]]: + with self._connect() as conn, conn.cursor() as cur: + cur.execute("SELECT payload FROM user_memory WHERE user_id = %s", (user_id,)) + row = cur.fetchone() + if not row: + return None + return self.cipher.unwrap(row[0] if isinstance(row[0], dict) else json.loads(row[0])) + + def save(self, user_id: str, memory: Dict[str, Any]) -> None: + payload = json.dumps(self.cipher.wrap(memory)) + with self._connect() as conn, conn.cursor() as cur: + cur.execute( + """ + INSERT INTO user_memory(user_id, payload, updated_at) + VALUES(%s, %s::jsonb, NOW()) + ON CONFLICT(user_id) DO UPDATE SET payload=excluded.payload, updated_at=excluded.updated_at + """, + (user_id, payload), + ) + + def add_entry(self, user_id: str, session_id: str, topic_id: str, text: str, metadata: Optional[Dict[str, Any]] = None) -> None: + with self._connect() as conn, conn.cursor() as cur: + cur.execute( + """ + INSERT INTO memory_entries(user_id, session_id, topic_id, text, metadata) + VALUES(%s, %s, %s, %s, %s::jsonb) + """, + (user_id, session_id, topic_id, text, json.dumps(metadata or {})), + ) + + def retrieve(self, user_id: str, session_id: Optional[str] = None, topic_id: Optional[str] = None, query: str = "", k: int = 5) -> List[Dict[str, Any]]: + clauses = ["user_id = %s"] + params: List[Any] = [user_id] + if session_id: + clauses.append("session_id = %s") + params.append(session_id) + if topic_id: + clauses.append("topic_id = %s") + params.append(topic_id) + params.append(max(1, k * 5)) + sql = f""" + SELECT session_id, topic_id, text, metadata, EXTRACT(EPOCH FROM created_at) + FROM memory_entries + WHERE {' AND '.join(clauses)} + ORDER BY created_at DESC + LIMIT %s + """ + with self._connect() as conn, conn.cursor() as cur: + cur.execute(sql, params) + rows = cur.fetchall() + entries = [ + { + "session_id": row[0], + "topic_id": row[1], + "text": row[2], + "metadata": row[3] or {}, + "created_at": float(row[4]), + } + for row in rows + ] + return _rank_entries(entries, session_id, topic_id, query, k) + + def delete_user(self, user_id: str) -> bool: + with self._connect() as conn, conn.cursor() as cur: + cur.execute("DELETE FROM user_memory WHERE user_id = %s", (user_id,)) + memory_deleted = cur.rowcount + cur.execute("DELETE FROM memory_entries WHERE user_id = %s", (user_id,)) + entries_deleted = cur.rowcount + return bool(memory_deleted or entries_deleted) + + def _connect(self): + return self._psycopg2.connect(self.dsn) + + def _init_schema(self) -> None: + with self._connect() as conn, conn.cursor() as cur: + cur.execute( + """ + CREATE TABLE IF NOT EXISTS user_memory( + user_id TEXT PRIMARY KEY, + payload JSONB NOT NULL, + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + """ + ) + cur.execute( + """ + CREATE TABLE IF NOT EXISTS memory_entries( + id BIGSERIAL PRIMARY KEY, + user_id TEXT NOT NULL, + session_id TEXT NOT NULL, + topic_id TEXT NOT NULL, + text TEXT NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + """ + ) + cur.execute( + "CREATE INDEX IF NOT EXISTS idx_smriti_entries_user_session_topic ON memory_entries(user_id, session_id, topic_id, created_at DESC)" + ) + + +def _fernet_key(secret: str) -> bytes: + raw = secret.encode("utf-8") + try: + base64.urlsafe_b64decode(raw) + if len(raw) == 44: + return raw + except Exception: + pass + return base64.urlsafe_b64encode(hashlib.sha256(raw).digest()) + + +def _safe_id(value: str) -> str: + return re.sub(r"[^a-zA-Z0-9_.-]+", "_", value.strip()) or "default" + + +def _entry(session_id: str, topic_id: str, text: str, metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]: + return { + "session_id": session_id, + "topic_id": topic_id, + "text": text, + "metadata": metadata or {}, + "created_at": time.time(), + } + + +def _rank_entries( + entries: List[Dict[str, Any]], + session_id: Optional[str], + topic_id: Optional[str], + query: str, + k: int, +) -> List[Dict[str, Any]]: + scoped = [ + entry + for entry in entries + if (not session_id or entry.get("session_id") == session_id) + and (not topic_id or entry.get("topic_id") == topic_id) + ] + if not query.strip(): + return sorted(scoped, key=lambda item: item.get("created_at", 0), reverse=True)[:k] + q_terms = set(re.findall(r"[a-z0-9']+", query.lower())) + scored = [] + for entry in scoped: + terms = set(re.findall(r"[a-z0-9']+", entry.get("text", "").lower())) + overlap = len(q_terms & terms) / max(1, len(q_terms | terms)) + recency = entry.get("created_at", 0) + scored.append((overlap, recency, entry)) + scored.sort(key=lambda item: (item[0], item[1]), reverse=True) + return [entry for _, _, entry in scored[:k]] diff --git a/smriti_vendor/smriti/cli.py b/smriti_vendor/smriti/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..4775cf5b177ce8eaf6cf17fc3256d61747b5dc5e --- /dev/null +++ b/smriti_vendor/smriti/cli.py @@ -0,0 +1,279 @@ +"""Command-line tools for local Smriti AI experimentation.""" + +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +from pathlib import Path +from typing import Any + +from cryptography.fernet import Fernet + +from .backends import build_backend +from .config import configure_environment_from_file, write_default_config +from .core import MemPalaceLite + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + prog="smriti-cli", + description="Smriti AI local CLI for configuration, memory, API, and benchmarks.", + ) + parser.add_argument("--config", help="Path to config.yaml. Defaults to SMRITI_CONFIG_PATH or ./config.yaml.") + parser.add_argument("--memory-path", default="smriti_memory.json") + parser.add_argument("--session-id", default="default", help="User/session id for isolated memory.") + parser.add_argument("--topic-id", default="general") + parser.add_argument( + "--backend", + choices=["json", "sqlite", "redis", "postgres"], + help="Durable backend. Defaults to config.yaml or SMRITI_MEMORY_BACKEND.", + ) + parser.add_argument("--backend-path", help="JSON root or SQLite path for local durable backends.") + parser.add_argument("--retrieval-mode", choices=["semantic", "tfidf"], default="semantic") + + subparsers = parser.add_subparsers(dest="command", required=True) + + init_parser = subparsers.add_parser("init", help="Create a config.yaml template.") + init_parser.add_argument("path", nargs="?", default="config.yaml") + init_parser.add_argument("--overwrite", action="store_true") + + config_parser = subparsers.add_parser("config", help="Configuration helpers.") + config_sub = config_parser.add_subparsers(dest="config_command", required=True) + wizard_parser = config_sub.add_parser("wizard", help="Create config.yaml with backend defaults.") + wizard_parser.add_argument("--path", default="config.yaml") + wizard_parser.add_argument("--backend", choices=["json", "sqlite", "redis", "postgres"]) + wizard_parser.add_argument("--encrypt", action="store_true", help="Generate a local Fernet encryption key.") + wizard_parser.add_argument("--overwrite", action="store_true") + + server_parser = subparsers.add_parser("start-server", help="Start the FastAPI server.") + server_parser.add_argument("--host", default=None) + server_parser.add_argument("--port", type=int, default=None) + server_parser.add_argument("--reload", action="store_true") + + chat_parser = subparsers.add_parser("chat", help="Store a message and show retrieved memory.") + chat_parser.add_argument("message") + + load_parser = subparsers.add_parser("load", help="Load memory JSON and print a summary.") + load_parser.add_argument("path", nargs="?") + + save_parser = subparsers.add_parser("save", help="Save memory to a path.") + save_parser.add_argument("path", nargs="?") + + delete_parser = subparsers.add_parser("delete", help="Delete all memory for a user.") + delete_parser.add_argument("--path", help="Optional memory JSON path to remove as well.") + + memory_parser = subparsers.add_parser("memory", help="Namespaced memory operations.") + memory_sub = memory_parser.add_subparsers(dest="memory_command", required=True) + memory_save = memory_sub.add_parser("save", help="Save memory for the configured user.") + memory_save.add_argument("path", nargs="?") + memory_load = memory_sub.add_parser("load", help="Load memory for the configured user.") + memory_load.add_argument("path", nargs="?") + memory_delete = memory_sub.add_parser("delete", help="Delete memory for the configured user.") + memory_delete.add_argument("--path", help="Optional memory JSON path to remove as well.") + + graph_parser = subparsers.add_parser("graph_query", help="Query the knowledge graph.") + graph_parser.add_argument("entity") + graph_parser.add_argument("--depth", type=int, default=1) + + benchmark_parser = subparsers.add_parser("benchmark", help="Run the Gemma 4 benchmark suite.") + benchmark_parser.add_argument("--max-new-tokens", type=int, default=80) + + args = parser.parse_args(argv) + if args.config: + os.environ["SMRITI_CONFIG_PATH"] = args.config + configure_environment_from_file(args.config) + + if args.command == "init": + path = write_default_config(args.path, overwrite=args.overwrite) + print(f"Wrote config template: {path}") + return 0 + + if args.command == "config" and args.config_command == "wizard": + path = _write_wizard_config( + args.path, + backend=args.backend or _prompt_backend(), + encrypt=args.encrypt, + overwrite=args.overwrite, + ) + print(f"Wrote Smriti AI config: {path}") + print(_config_next_steps(path)) + return 0 + + if args.command == "start-server": + return _start_server(host=args.host, port=args.port, reload=args.reload) + + if args.command == "benchmark": + return _run_benchmark(args.max_new_tokens) + + memory = _load_or_new(args.memory_path, args.retrieval_mode, args.session_id, args.topic_id) + + if args.command == "chat": + context = memory.get_context(query=args.message, session_id=args.session_id, topic_id=args.topic_id) + facts = memory.extract_facts("", user_input=args.message) + for fact in facts: + memory.add_fact(fact, session_id=args.session_id, topic_id=args.topic_id) + memory.add_to_history("User: " + args.message, "user_input") + memory.save(args.memory_path) + _save_to_backend_if_requested(args, memory) + print(context or "No relevant prior memory.") + if facts: + print("\nStored facts:") + for fact in facts: + print(f"- {fact}") + return 0 + + if args.command == "load" or (args.command == "memory" and args.memory_command == "load"): + load_path = getattr(args, "path", None) or args.memory_path + if Path(load_path).exists(): + loaded = MemPalaceLite.load(load_path, retrieval_mode=args.retrieval_mode) + else: + state = _backend(args).load(args.session_id) + if state is None: + raise SystemExit(f"No memory found for user {args.session_id!r}.") + loaded = MemPalaceLite.from_dict(state, retrieval_mode=args.retrieval_mode) + print(_summary(loaded)) + return 0 + + if args.command == "save" or (args.command == "memory" and args.memory_command == "save"): + memory.save(getattr(args, "path", None) or args.memory_path) + _save_to_backend_if_requested(args, memory) + print(_summary(memory)) + return 0 + + if args.command == "delete" or (args.command == "memory" and args.memory_command == "delete"): + print(_delete_memory(args)) + return 0 + + if args.command == "graph_query": + triples = memory.knowledge_graph.query_graph( + args.session_id, + args.entity, + depth=args.depth, + topic_id=args.topic_id, + ) + print(json.dumps([triple.__dict__ for triple in triples], indent=2)) + return 0 + + return 1 + + +def _load_or_new(path: str, retrieval_mode: str, session_id: str, topic_id: str) -> MemPalaceLite: + if Path(path).exists(): + memory = MemPalaceLite.load(path, retrieval_mode=retrieval_mode) + memory.session_id = session_id + memory.topic_id = topic_id + return memory + return MemPalaceLite(retrieval_mode=retrieval_mode, session_id=session_id, topic_id=topic_id) + + +def _summary(memory: MemPalaceLite) -> str: + state: dict[str, Any] = memory.to_dict() + semantic_sessions = state.get("semantic_memory", {}).get("sessions", {}) if state.get("semantic_memory") else {} + return json.dumps( + { + "retrieval_mode": state["retrieval_mode"], + "facts": len(state["key_facts"]), + "history": len(state["history"]), + "semantic_sessions": list(semantic_sessions.keys()), + }, + indent=2, + ) + + +def _backend(args: argparse.Namespace): + kwargs = {} + if args.backend_path: + if args.backend == "sqlite": + kwargs["path"] = args.backend_path + else: + kwargs["root"] = args.backend_path + return build_backend(args.backend, **kwargs) + + +def _save_to_backend_if_requested(args: argparse.Namespace, memory: MemPalaceLite) -> None: + if args.backend or args.backend_path or os.getenv("SMRITI_AUTOSAVE", "0").lower() in {"1", "true", "yes"}: + _backend(args).save(args.session_id, memory.to_dict()) + + +def _delete_memory(args: argparse.Namespace) -> str: + deleted_file = False + target = Path(getattr(args, "path", None) or args.memory_path) + if target.exists(): + target.unlink() + deleted_file = True + deleted_backend = _backend(args).delete_user(args.session_id) + return json.dumps( + {"user_id": args.session_id, "deleted_file": deleted_file, "deleted_backend": deleted_backend}, + indent=2, + ) + + +def _write_wizard_config(path: str, backend: str, encrypt: bool, overwrite: bool) -> Path: + config_path = write_default_config(path, overwrite=overwrite) + text = config_path.read_text(encoding="utf-8") + text = text.replace("backend: json", f"backend: {backend}") + if encrypt: + key = Fernet.generate_key().decode("utf-8") + text = text.replace('encryption_key: ""', f'encryption_key: "{key}"') + config_path.write_text(text, encoding="utf-8") + return config_path + + +def _prompt_backend() -> str: + if not sys.stdin.isatty(): + return "json" + options = ["json", "sqlite", "redis", "postgres"] + print("Choose a memory backend:") + for idx, option in enumerate(options, start=1): + print(f" {idx}. {option}") + answer = input("Backend [1=json]: ").strip() + if not answer: + return "json" + if answer.isdigit() and 1 <= int(answer) <= len(options): + return options[int(answer) - 1] + if answer.lower() in options: + return answer.lower() + print("Unknown backend; using json.") + return "json" + + +def _config_next_steps(path: Path) -> str: + return "\n".join( + [ + "Next steps:", + f" export SMRITI_CONFIG_PATH={path}", + " smriti-cli start-server --host 0.0.0.0 --port 8000", + " smriti-cli --session-id alex chat \"My name is Alex and I work at Ocean Lab.\"", + ] + ) + + +def _start_server(host: str | None, port: int | None, reload: bool) -> int: + import uvicorn + + uvicorn.run( + "smriti.api:app", + host=host or os.getenv("SMRITI_HOST", "0.0.0.0"), + port=port or int(os.getenv("SMRITI_PORT", "8000")), + reload=reload, + ) + return 0 + + +def _run_benchmark(max_new_tokens: int) -> int: + cmd = [ + sys.executable, + "benchmarks/run_benchmarks.py", + "--model-preset", + "gemma4", + "--max-new-tokens", + str(max_new_tokens), + ] + return subprocess.call(cmd) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/smriti_vendor/smriti/config.py b/smriti_vendor/smriti/config.py new file mode 100644 index 0000000000000000000000000000000000000000..0c42a08e1eef3eecdeb58434159025277dc87abf --- /dev/null +++ b/smriti_vendor/smriti/config.py @@ -0,0 +1,151 @@ +"""Configuration helpers for Smriti AI deployments.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import yaml + + +DEFAULT_CONFIG_PATH = Path("config.yaml") + + +@dataclass +class SmritiConfig: + """Runtime configuration loaded from YAML and environment variables.""" + + backend: str = "json" + memory_dir: str = "data/memory" + sqlite_path: str = "data/smriti_memory.sqlite3" + redis_url: str = "redis://localhost:6379/0" + postgres_dsn: str = "" + encryption_key: str = "" + autosave: bool = False + host: str = "0.0.0.0" + port: int = 8000 + cors_origins: list[str] = field(default_factory=lambda: ["*"]) + + +def load_config(path: str | Path | None = None) -> SmritiConfig: + """Load a Smriti AI YAML config, returning sensible defaults when absent.""" + + config_path = Path(path or os.getenv("SMRITI_CONFIG_PATH") or DEFAULT_CONFIG_PATH) + if not config_path.exists(): + return SmritiConfig() + data = _expand_env(yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}) + memory = data.get("memory", data) + api = data.get("api", {}) + security = data.get("security", {}) + cors_origins = list(api.get("cors_origins", ["*"])) + if os.getenv("SMRITI_CORS_ORIGINS"): + cors_origins = [ + origin.strip() + for origin in os.environ["SMRITI_CORS_ORIGINS"].split(",") + if origin.strip() + ] or ["*"] + return SmritiConfig( + backend=str(os.getenv("SMRITI_MEMORY_BACKEND", memory.get("backend", "json"))), + memory_dir=str(os.getenv("SMRITI_MEMORY_DIR", memory.get("memory_dir", "data/memory"))), + sqlite_path=str( + os.getenv("SMRITI_SQLITE_PATH", memory.get("sqlite_path", "data/smriti_memory.sqlite3")) + ), + redis_url=str( + os.getenv("SMRITI_REDIS_URL", memory.get("redis_url", "redis://localhost:6379/0")) + ), + postgres_dsn=str(os.getenv("SMRITI_POSTGRES_DSN", memory.get("postgres_dsn", ""))), + encryption_key=str( + os.getenv("SMRITI_MEMORY_KEY", security.get("encryption_key", memory.get("encryption_key", ""))) + ), + autosave=_bool_env("SMRITI_AUTOSAVE", memory.get("autosave", False)), + host=str(os.getenv("SMRITI_HOST", api.get("host", "0.0.0.0"))), + port=int(os.getenv("SMRITI_PORT", api.get("port", 8000))), + cors_origins=cors_origins, + ) + + +def _expand_env(value: Any) -> Any: + """Recursively expand `${VAR}` placeholders in YAML values. + + Empty unresolved placeholders intentionally become empty strings instead of + literal secrets such as `${SMRITI_MEMORY_KEY}`. This keeps sample configs + safe to commit while still allowing production deployments to inject secrets + through the environment. + """ + + if isinstance(value, str): + expanded = os.path.expandvars(value) + if expanded == value and value.startswith("${") and value.endswith("}"): + return "" + return expanded + if isinstance(value, list): + return [_expand_env(item) for item in value] + if isinstance(value, dict): + return {key: _expand_env(item) for key, item in value.items()} + return value + + +def _bool_env(name: str, default: Any) -> bool: + raw = os.getenv(name) + if raw is None: + return bool(default) + return raw.lower() in {"1", "true", "yes", "on"} + + +def configure_environment_from_file(path: str | Path | None = None) -> SmritiConfig: + """Apply config defaults to environment variables without overwriting users.""" + + config = load_config(path) + defaults: dict[str, Any] = { + "SMRITI_MEMORY_BACKEND": config.backend, + "SMRITI_MEMORY_DIR": config.memory_dir, + "SMRITI_SQLITE_PATH": config.sqlite_path, + "SMRITI_REDIS_URL": config.redis_url, + "SMRITI_POSTGRES_DSN": config.postgres_dsn, + "SMRITI_AUTOSAVE": "1" if config.autosave else "0", + "SMRITI_HOST": config.host, + "SMRITI_PORT": str(config.port), + "SMRITI_CORS_ORIGINS": ",".join(config.cors_origins), + } + if config.encryption_key: + defaults["SMRITI_MEMORY_KEY"] = config.encryption_key + for key, value in defaults.items(): + if value not in {"", None}: + os.environ.setdefault(key, str(value)) + return config + + +def write_default_config(path: str | Path = DEFAULT_CONFIG_PATH, overwrite: bool = False) -> Path: + """Write a customer-friendly config template.""" + + config_path = Path(path) + if config_path.exists() and not overwrite: + return config_path + config_path.parent.mkdir(parents=True, exist_ok=True) + config_path.write_text(DEFAULT_CONFIG_YAML, encoding="utf-8") + return config_path + + +DEFAULT_CONFIG_YAML = """# Smriti AI local configuration. +memory: + # Options: json, sqlite, redis, postgres + backend: json + memory_dir: data/memory + sqlite_path: data/smriti_memory.sqlite3 + redis_url: redis://localhost:6379/0 + postgres_dsn: "" + autosave: true + +security: + # Optional Fernet-compatible secret or passphrase-derived key. + # Prefer setting SMRITI_MEMORY_KEY in production instead of committing a key. + encryption_key: "" + +api: + host: 0.0.0.0 + port: 8000 + cors_origins: + - "*" +""" diff --git a/smriti_vendor/smriti/core.py b/smriti_vendor/smriti/core.py new file mode 100644 index 0000000000000000000000000000000000000000..af1ad9a18a9200af9f9d4e726e8cb7f8f6ea0149 --- /dev/null +++ b/smriti_vendor/smriti/core.py @@ -0,0 +1,481 @@ +import json +import math +import re +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Optional + +from .knowledge_graph import KnowledgeGraphMemory +from .semantic_memory import DEFAULT_EMBEDDING_MODEL, SemanticMemory + + +@dataclass +class MemoryEntry: + content: str + timestamp: int + relevance_score: float = 1.0 + category: str = "general" + + +class MemPalaceLite: + """ + Structured external memory for SLMs. + + The legacy TF-IDF-style fact store is preserved for compatibility, while + the default retrieval path now uses hierarchical semantic memory: + sessions -> topics -> embedded facts. + + Decay formula: relevance * exp(-decay_rate * age_in_steps) + """ + + def __init__( + self, + max_history: int = 10, + max_facts: int = 20, + decay_rate: float = 0.05, + retrieval_mode: str = "semantic", + session_id: str = "default", + topic_id: str = "general", + embedding_model_name: str = DEFAULT_EMBEDDING_MODEL, + semantic_memory: Optional[SemanticMemory] = None, + knowledge_graph: Optional[KnowledgeGraphMemory] = None, + max_entries_per_topic: Optional[int] = None, + **semantic_kwargs: Any, + ): + if retrieval_mode not in {"tfidf", "semantic"}: + raise ValueError("retrieval_mode must be 'tfidf' or 'semantic'.") + + self.history: List[MemoryEntry] = [] + self.key_facts: List[MemoryEntry] = [] + self.patterns: List[str] = [] + self.max_history = max_history + self.max_facts = max_facts + self.decay_rate = decay_rate + self.step_counter = 0 + self.retrieval_mode = retrieval_mode + self.session_id = session_id + self.topic_id = topic_id + self.embedding_model_name = embedding_model_name + + self.semantic_memory = semantic_memory + if self.retrieval_mode == "semantic" and self.semantic_memory is None: + self.semantic_memory = SemanticMemory( + embedding_model_name=embedding_model_name, + decay_rate=decay_rate, + max_entries_per_topic=max_entries_per_topic or max_facts, + **semantic_kwargs, + ) + self.knowledge_graph = knowledge_graph or KnowledgeGraphMemory() + + # ------------------------------------------------------------------ + # Core memory operations + # ------------------------------------------------------------------ + + def _decayed_relevance(self, entry: MemoryEntry) -> float: + age = self.step_counter - entry.timestamp + return entry.relevance_score * math.exp(-self.decay_rate * age) + + def add_to_history(self, interaction: str, category: str = "interaction"): + self.step_counter += 1 + self.history.append( + MemoryEntry(interaction, self.step_counter, category=category) + ) + if self.semantic_memory is not None: + self.semantic_memory.step_counter = max( + self.semantic_memory.step_counter, + self.step_counter, + ) + if len(self.history) > self.max_history: + self.history = self.history[-self.max_history :] + + def add_fact( + self, + fact: str, + relevance: float = 1.0, + session_id: Optional[str] = None, + topic_id: Optional[str] = None, + ): + fact = fact.strip() + if not fact or len(fact) < 8: + return + for entry in self.key_facts: + if ( + fact.lower()[:40] in entry.content.lower() + or entry.content.lower()[:40] in fact.lower() + ): + return + + session_id = session_id or self.session_id + topic_id = topic_id or self.topic_id + self.key_facts.append( + MemoryEntry( + fact, + self.step_counter, + relevance_score=relevance, + category="fact", + ) + ) + + if self.semantic_memory is not None: + self.semantic_memory.add_entry( + session_id, + topic_id, + fact, + metadata={"category": "fact", "relevance": relevance}, + ) + self.knowledge_graph.add_statement(session_id, topic_id, fact) + + if len(self.key_facts) > self.max_facts: + self.key_facts.sort(key=self._decayed_relevance, reverse=True) + self.key_facts = self.key_facts[: self.max_facts] + + def add_pattern(self, pattern: str): + if pattern and pattern not in self.patterns: + self.patterns.append(pattern) + + # ------------------------------------------------------------------ + # Retrieval and context assembly + # ------------------------------------------------------------------ + + def retrieve_facts( + self, + query: str = "", + k: int = 5, + session_id: Optional[str] = None, + topic_id: Optional[str] = None, + ) -> List[str]: + session_id = session_id or self.session_id + topic_id = topic_id or self.topic_id + + if self.retrieval_mode == "semantic" and self.semantic_memory is not None: + if query.strip(): + return [ + result.entry.text + for result in self.semantic_memory.retrieve( + session_id, topic_id, query, k=k + ) + ] + return [ + entry.text + for entry in self.semantic_memory.recent_entries(session_id, topic_id, k=k) + ] + + return self._retrieve_tfidf(query=query, k=k) + + def get_context( + self, + query: str = "", + top_facts: int = 5, + top_history: int = 3, + session_id: Optional[str] = None, + topic_id: Optional[str] = None, + include_graph: bool = True, + ) -> str: + session_id = session_id or self.session_id + topic_id = topic_id or self.topic_id + parts = [] + + facts = self.retrieve_facts(query, k=top_facts, session_id=session_id, topic_id=topic_id) + if facts: + facts_text = "\n".join(f" * {fact}" for fact in facts[:top_facts]) + parts.append(f"[REMEMBERED FACTS]\n{facts_text}") + + if include_graph: + graph_facts = self._graph_context(query, session_id=session_id, topic_id=topic_id) + if graph_facts: + graph_text = "\n".join(f" * {fact}" for fact in graph_facts[:top_facts]) + parts.append(f"[RELATED GRAPH FACTS]\n{graph_text}") + + if self.history: + hist_text = "\n".join( + f" {entry.content}" for entry in self.history[-top_history:] + ) + parts.append(f"[RECENT CONVERSATION]\n{hist_text}") + if self.patterns: + pat_text = "\n".join(f" -> {pattern}" for pattern in self.patterns[-2:]) + parts.append(f"[REASONING PATTERNS]\n{pat_text}") + return "\n\n".join(parts) + + def _retrieve_tfidf(self, query: str, k: int) -> List[str]: + if not self.key_facts: + return [] + if not query.strip(): + sorted_facts = sorted( + self.key_facts, key=self._decayed_relevance, reverse=True + ) + return [entry.content for entry in sorted_facts[:k]] + + try: + from sklearn.feature_extraction.text import TfidfVectorizer + from sklearn.metrics.pairwise import cosine_similarity + + corpus = [query] + [entry.content for entry in self.key_facts] + tfidf = TfidfVectorizer(stop_words="english").fit_transform(corpus) + sims = cosine_similarity(tfidf[0:1], tfidf[1:])[0] + scored = [ + (float(sim) * self._decayed_relevance(entry), entry) + for sim, entry in zip(sims, self.key_facts) + ] + except Exception: + query_terms = set(re.findall(r"[a-z0-9']+", query.lower())) + scored = [] + for entry in self.key_facts: + terms = set(re.findall(r"[a-z0-9']+", entry.content.lower())) + overlap = len(query_terms & terms) / max(1, len(query_terms | terms)) + scored.append((overlap * self._decayed_relevance(entry), entry)) + + scored.sort(key=lambda item: item[0], reverse=True) + return [entry.content for _, entry in scored[:k]] + + def _graph_context(self, query: str, session_id: str, topic_id: str) -> List[str]: + terms = _graph_query_terms(query) + if not terms and any(word in query.lower() for word in ["me", "my", "mine"]): + terms = ["user"] + graph_texts: List[str] = [] + seen = set() + for term in terms: + triples = self.knowledge_graph.query_graph( + session_id, term, depth=1, topic_id=topic_id + ) + for text in self.knowledge_graph.triples_to_text(triples): + if text not in seen: + seen.add(text) + graph_texts.append(text) + return graph_texts + + # ------------------------------------------------------------------ + # Fact extraction + # ------------------------------------------------------------------ + + def extract_facts(self, text: str, user_input: str = "") -> List[str]: + """ + Extract factual sentences from user_input + model text. + + user_input is combined with text so injected personal facts + ("My name is Jordan") are captured. Question sentences are + filtered by first-word check to avoid storing distractors. + + Secondary extraction is restricted to user_input only to prevent + noisy model narration from polluting the fact store. + """ + question_starters = ( + "what", + "where", + "when", + "who", + "how", + "why", + "is", + "are", + "do", + "does", + "did", + "can", + "will", + "could", + "would", + "should", + ) + patterns = [ + r"my name is ([\w ]+)", + r"i am (?:a |an )?([\w ]+)", + r"i work (?:as|at|for) ([\w ]+)", + r"i(?:'m| am) (?:currently )?(?:studying|working on|researching|building) ([\w ]+)", + r"i (?:live|am based) (?:in|at) ([\w ,]+)", + ] + + def _non_question_sentences(src: str) -> List[str]: + out = [] + for sentence in re.split(r"[.!?]", src): + sentence = sentence.strip() + if not sentence: + continue + first = sentence.lower().split()[0] if sentence.split() else "" + if first not in question_starters: + out.append(sentence) + return out + + combined = (user_input + " " + text).strip() + combined_sentences = _non_question_sentences(combined) + user_sentences = _non_question_sentences(user_input) + + facts = [] + + # Primary: regex patterns on combined sentences + for sent in combined_sentences: + for pattern in patterns: + if re.search(pattern, sent.lower()) and 8 < len(sent) < 200: + facts.append(sent[:200]) + break + + # Secondary: general facts -- user_input sentences only + if len(facts) < 3: + for sent in user_sentences: + if 20 < len(sent) < 150: + if any( + kw in sent.lower() + for kw in [ + " is ", + " are ", + " was ", + "capital", + "located", + "known for", + ] + ): + if sent not in facts: + facts.append(sent[:150]) + + seen, unique = set(), [] + for fact in facts: + key = fact.lower()[:30] + if key not in seen: + seen.add(key) + unique.append(fact) + return unique[:4] + + # ------------------------------------------------------------------ + # Persistence + # ------------------------------------------------------------------ + + def save(self, path: str): + with open(path, "w", encoding="utf-8") as handle: + json.dump(self.to_dict(), handle, indent=2) + print( + f"Memory saved -> {path} " + f"({len(self.key_facts)} facts, {len(self.history)} history)" + ) + + @classmethod + def load(cls, path: str, **kwargs: Any) -> "MemPalaceLite": + with open(path, encoding="utf-8") as handle: + data = json.load(handle) + inst = cls.from_dict(data, **kwargs) + print( + f"Memory loaded <- {path} " + f"({len(inst.key_facts)} facts, {len(inst.history)} history)" + ) + return inst + + @classmethod + def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> "MemPalaceLite": + semantic_data = data.get("semantic_memory") + graph_data = data.get("knowledge_graph") + + semantic_kwargs = dict(kwargs) + max_history = semantic_kwargs.pop("max_history", data.get("max_history", 10)) + max_facts = semantic_kwargs.pop("max_facts", data.get("max_facts", 20)) + decay_rate = semantic_kwargs.pop("decay_rate", data.get("decay_rate", 0.05)) + retrieval_mode = semantic_kwargs.pop( + "retrieval_mode", + data.get("retrieval_mode", "semantic" if semantic_data else "tfidf"), + ) + session_id = semantic_kwargs.pop("session_id", data.get("session_id", "default")) + topic_id = semantic_kwargs.pop("topic_id", data.get("topic_id", "general")) + embedding_model_name = semantic_kwargs.pop( + "embedding_model_name", + data.get("embedding_model_name", DEFAULT_EMBEDDING_MODEL), + ) + semantic_memory = semantic_kwargs.pop("semantic_memory", None) + knowledge_graph = semantic_kwargs.pop("knowledge_graph", None) + if semantic_data and semantic_memory is None: + semantic_memory = SemanticMemory.from_dict(semantic_data, **semantic_kwargs) + + inst = cls( + max_history=max_history, + max_facts=max_facts, + decay_rate=decay_rate, + retrieval_mode=retrieval_mode, + session_id=session_id, + topic_id=topic_id, + embedding_model_name=embedding_model_name, + semantic_memory=semantic_memory, + knowledge_graph=knowledge_graph + or (KnowledgeGraphMemory.from_dict(graph_data) if graph_data else None), + **semantic_kwargs, + ) + inst.history = [MemoryEntry(**entry) for entry in data.get("history", [])] + inst.key_facts = [MemoryEntry(**entry) for entry in data.get("key_facts", [])] + inst.patterns = list(data.get("patterns", [])) + inst.step_counter = data.get("step_counter", 0) + if ( + inst.retrieval_mode == "semantic" + and inst.semantic_memory is not None + and not semantic_data + ): + for entry in inst.key_facts: + inst.semantic_memory.add_entry( + inst.session_id, + inst.topic_id, + entry.content, + metadata={ + "category": entry.category, + "relevance": entry.relevance_score, + "migrated_from_legacy": True, + }, + ) + if not graph_data: + for entry in inst.key_facts: + inst.knowledge_graph.add_statement( + inst.session_id, + inst.topic_id, + entry.content, + ) + return inst + + def to_dict(self) -> Dict[str, Any]: + return { + "version": 2, + "retrieval_mode": self.retrieval_mode, + "session_id": self.session_id, + "topic_id": self.topic_id, + "max_history": self.max_history, + "max_facts": self.max_facts, + "decay_rate": self.decay_rate, + "embedding_model_name": self.embedding_model_name, + "history": [asdict(entry) for entry in self.history], + "key_facts": [asdict(entry) for entry in self.key_facts], + "patterns": self.patterns, + "step_counter": self.step_counter, + "semantic_memory": self.semantic_memory.to_dict() + if self.semantic_memory is not None + else None, + "knowledge_graph": self.knowledge_graph.to_dict(), + } + + +def _graph_query_terms(query: str) -> List[str]: + terms = [] + lowered = query.lower() + if any(word in lowered for word in ["me", "my", "mine", "myself"]): + terms.append("user") + for match in re.finditer(r"\b[A-Z][a-zA-Z0-9_ -]{2,}\b", query): + terms.append(match.group(0)) + for pattern in [ + r"about ([\w .,'-]+)", + r"remember ([\w .,'-]+)", + r"where do i work", + r"what is my name", + r"who\s+\w+\s+([\w .,'-]+)", + r"what did ([\w .,'-]+?)\s+\w+", + ]: + found = re.search(pattern, lowered) + if found: + if found.groups(): + terms.append(found.group(1)) + else: + terms.append("user") + stop = { + "who", "what", "when", "where", "why", "how", "did", "does", "do", + "is", "are", "the", "a", "an", "discover", "discovered", "work", + } + for token in re.findall(r"[a-zA-Z0-9][a-zA-Z0-9'-]{2,}", lowered): + if token not in stop: + terms.append(token) + deduped = [] + seen = set() + for term in terms: + cleaned = term.strip(" ?.,") + if cleaned and cleaned not in seen: + seen.add(cleaned) + deduped.append(cleaned) + return deduped diff --git a/smriti_vendor/smriti/gifp.py b/smriti_vendor/smriti/gifp.py new file mode 100644 index 0000000000000000000000000000000000000000..de8448809a6e8b357c457632961a0448580fc058 --- /dev/null +++ b/smriti_vendor/smriti/gifp.py @@ -0,0 +1,8 @@ +"""Backward-compatible GIFP import path.""" + +from .identity_fingerprint import IdentityFingerprint + + +class GIFPLite(IdentityFingerprint): + """Compatibility alias for the embedding-based GIFP v1.0 implementation.""" + diff --git a/smriti_vendor/smriti/identity_fingerprint.py b/smriti_vendor/smriti/identity_fingerprint.py new file mode 100644 index 0000000000000000000000000000000000000000..1d20de46b9da8403362612cd3fb0a8f4e59ea762 --- /dev/null +++ b/smriti_vendor/smriti/identity_fingerprint.py @@ -0,0 +1,267 @@ +"""Embedding-based identity governance for Smriti AI.""" + +from __future__ import annotations + +import math +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence + +import numpy as np + +from .semantic_memory import DEFAULT_EMBEDDING_MODEL, _HashingEmbeddingModel, _normalize + + +@dataclass +class IdentityCheck: + """Result of comparing an output to the current persona fingerprint.""" + + distance: float + threshold: float + needs_refinement: bool + consistency_score: float + + +class IdentityFingerprint: + """ + GIFP v1.0 identity governance based on sentence embeddings. + + A persona vector is built from user self-descriptions and early responses. + Each model output is compared against that vector; outputs that drift too + far can be regenerated with the persona fingerprint prepended. + """ + + def __init__( + self, + role: str = "helpful assistant with persistent memory", + embedding_model_name: str = DEFAULT_EMBEDDING_MODEL, + threshold: float = 0.35, + device: str = "auto", + embedding_model: Any = None, + ): + self.role_definition = role + self.embedding_model_name = embedding_model_name + self.threshold = threshold + self.device = _resolve_device(device) + self.embedding_model = embedding_model or self._load_embedding_model() + self.constraints: List[str] = [] + self.persona_vector: Optional[np.ndarray] = None + self.persona_evidence: List[str] = [] + self.behavior_history: List[str] = [] + + # ------------------------------------------------------------------ + # Compatibility prompt methods + # ------------------------------------------------------------------ + + def set_constraints(self, constraints: List[str]) -> None: + self.constraints = constraints + + def get_identity_prompt(self) -> str: + prompt = f"You are a {self.role_definition}.\n" + if self.constraints: + prompt += "\nBehavioural guidelines:\n" + for constraint in self.constraints: + prompt += f"- {constraint}\n" + if self.persona_evidence: + prompt += "\nPersona fingerprint evidence:\n" + for evidence in self.persona_evidence[-5:]: + prompt += f"- {evidence}\n" + prompt += "\nMaintain consistency across all interactions.\n" + return prompt + + # ------------------------------------------------------------------ + # Persona vector lifecycle + # ------------------------------------------------------------------ + + def initialize_persona(self, texts: Sequence[str]) -> None: + evidence = [text.strip() for text in texts if text and text.strip()] + if not evidence: + return + vectors = np.asarray([self._embed(text) for text in evidence], dtype=np.float32) + self.persona_vector = _normalize(vectors.mean(axis=0).reshape(1, -1))[0] + self.persona_evidence = evidence[-12:] + + def observe_user_input(self, user_input: str) -> None: + """Update the persona from self-descriptive user turns.""" + + if _looks_personal(user_input) or len(self.persona_evidence) < 3: + self.update_persona(user_input, weight=0.25) + + def update_persona(self, text: str, weight: float = 0.15) -> None: + cleaned = text.strip() + if not cleaned or len(cleaned) < 8: + return + vector = self._embed(cleaned) + if self.persona_vector is None: + self.persona_vector = vector + else: + blended = (1.0 - weight) * self.persona_vector + weight * vector + self.persona_vector = _normalize(blended.reshape(1, -1))[0] + self.persona_evidence.append(cleaned[:300]) + self.persona_evidence = self.persona_evidence[-12:] + + # ------------------------------------------------------------------ + # Drift detection and refinement + # ------------------------------------------------------------------ + + def evaluate_output(self, output: str) -> IdentityCheck: + if self.persona_vector is None or not output.strip(): + return IdentityCheck(0.0, self.effective_threshold, False, 1.0) + output_vector = self._embed(output) + distance = cosine_distance(self.persona_vector, output_vector) + threshold = self.effective_threshold + return IdentityCheck( + distance=distance, + threshold=threshold, + needs_refinement=distance > threshold, + consistency_score=max(0.0, 1.0 - distance), + ) + + def check_consistency(self, output: str) -> float: + """Compatibility method returning a score instead of a structured check.""" + + return self.evaluate_output(output).consistency_score + + def refinement_pass( + self, + generate_fn: Any, + output: str, + user_input: str = "", + context: str = "", + max_tokens: int = 256, + ) -> str: + """Regenerate an output with persona evidence prepended.""" + + prompt = self._refinement_prompt(output, user_input=user_input, context=context) + try: + return generate_fn(prompt, max_tokens=max_tokens) + except TypeError: + return generate_fn(prompt) + + def ensure_aligned( + self, + output: str, + generate_fn: Any, + user_input: str = "", + context: str = "", + max_tokens: int = 256, + ) -> tuple[str, IdentityCheck]: + check = self.evaluate_output(output) + if not check.needs_refinement: + return output, check + refined = self.refinement_pass( + generate_fn, + output, + user_input=user_input, + context=context, + max_tokens=max_tokens, + ) + refined_check = self.evaluate_output(refined) + return refined, refined_check + + def record_behavior(self, output: str) -> None: + if output and len(output) > 10: + self.behavior_history.append(output[:300]) + self.behavior_history = self.behavior_history[-20:] + if len(self.persona_evidence) < 3: + self.update_persona(output, weight=0.1) + + @property + def effective_threshold(self) -> float: + adjustment = min(0.15, math.log1p(len(self.persona_evidence)) * 0.025) + return self.threshold + adjustment + + def to_dict(self) -> Dict[str, Any]: + return { + "role_definition": self.role_definition, + "embedding_model_name": self.embedding_model_name, + "threshold": self.threshold, + "constraints": self.constraints, + "persona_vector": self.persona_vector.tolist() + if self.persona_vector is not None + else None, + "persona_evidence": self.persona_evidence, + "behavior_history": self.behavior_history, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> "IdentityFingerprint": + init_kwargs = { + "role": data.get("role_definition", "helpful assistant with persistent memory"), + "embedding_model_name": data.get("embedding_model_name", DEFAULT_EMBEDDING_MODEL), + "threshold": data.get("threshold", 0.35), + } + init_kwargs.update(kwargs) + inst = cls(**init_kwargs) + vector = data.get("persona_vector") + inst.persona_vector = np.asarray(vector, dtype=np.float32) if vector else None + inst.constraints = list(data.get("constraints", [])) + inst.persona_evidence = list(data.get("persona_evidence", [])) + inst.behavior_history = list(data.get("behavior_history", [])) + return inst + + def _load_embedding_model(self) -> Any: + try: + from sentence_transformers import SentenceTransformer + + return SentenceTransformer(self.embedding_model_name, device=self.device) + except Exception: + return _HashingEmbeddingModel() + + def _embed(self, text: str) -> np.ndarray: + vector = self.embedding_model.encode( + text, + convert_to_numpy=True, + normalize_embeddings=True, + ) + vector = np.asarray(vector, dtype=np.float32) + if vector.ndim > 1: + vector = vector[0] + return _normalize(vector.reshape(1, -1))[0] + + def _refinement_prompt(self, output: str, user_input: str, context: str) -> str: + evidence = "\n".join(f"- {item}" for item in self.persona_evidence[-6:]) + constraints = "\n".join(f"- {item}" for item in self.constraints) + return ( + "[PERSONA FINGERPRINT]\n" + f"Role: {self.role_definition}\n" + f"{evidence}\n\n" + "[BEHAVIOURAL GUIDELINES]\n" + f"{constraints}\n\n" + "[RETRIEVED CONTEXT]\n" + f"{context}\n\n" + "[USER MESSAGE]\n" + f"{user_input}\n\n" + "[DRAFT OUTPUT]\n" + f"{output}\n\n" + "Regenerate the answer so it remains accurate, helpful and aligned " + "with the persona fingerprint." + ) + + +def cosine_distance(left: np.ndarray, right: np.ndarray) -> float: + left = _normalize(np.asarray(left, dtype=np.float32).reshape(1, -1))[0] + right = _normalize(np.asarray(right, dtype=np.float32).reshape(1, -1))[0] + similarity = float(np.dot(left, right)) + return max(0.0, min(2.0, 1.0 - similarity)) + + +def _looks_personal(text: str) -> bool: + lowered = text.lower() + return bool( + re.search( + r"\b(my name is|i am|i'm|i work|i live|i study|i research|i build|i prefer|my favorite|my favourite)\b", + lowered, + ) + ) + + +def _resolve_device(device: str) -> str: + if device != "auto": + return device + try: + import torch + + return "cuda" if torch.cuda.is_available() else "cpu" + except Exception: + return "cpu" diff --git a/smriti_vendor/smriti/integrations/__init__.py b/smriti_vendor/smriti/integrations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6f880f03c9070f9bc1d71163a3a59d68cb6feb8a --- /dev/null +++ b/smriti_vendor/smriti/integrations/__init__.py @@ -0,0 +1,13 @@ +"""Optional framework adapters for Smriti AI.""" + +try: + from .langchain import SmritiMemory +except Exception: # pragma: no cover - optional dependency surface. + SmritiMemory = None + +try: + from .llama_index import SmritiStorageContext +except Exception: # pragma: no cover - optional dependency surface. + SmritiStorageContext = None + +__all__ = ["SmritiMemory", "SmritiStorageContext"] diff --git a/smriti_vendor/smriti/integrations/__pycache__/__init__.cpython-310.pyc b/smriti_vendor/smriti/integrations/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45b40022020ddf24e74c2d1d2bd77c7cef3913d2 Binary files /dev/null and b/smriti_vendor/smriti/integrations/__pycache__/__init__.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/integrations/__pycache__/langchain.cpython-310.pyc b/smriti_vendor/smriti/integrations/__pycache__/langchain.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..399e9110b8dc09703ae2ce9737cc88251b3b1a9f Binary files /dev/null and b/smriti_vendor/smriti/integrations/__pycache__/langchain.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/integrations/__pycache__/llama_index.cpython-310.pyc b/smriti_vendor/smriti/integrations/__pycache__/llama_index.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..362813d2a44d9440d8a2e00707d05bdb03e7138b Binary files /dev/null and b/smriti_vendor/smriti/integrations/__pycache__/llama_index.cpython-310.pyc differ diff --git a/smriti_vendor/smriti/integrations/langchain.py b/smriti_vendor/smriti/integrations/langchain.py new file mode 100644 index 0000000000000000000000000000000000000000..3216f8f08af432d74ebd6d6dde23c61352921e6a --- /dev/null +++ b/smriti_vendor/smriti/integrations/langchain.py @@ -0,0 +1,49 @@ +"""LangChain adapter for Smriti AI memory.""" + +from __future__ import annotations + +from typing import Any + +from smriti import MemPalaceLite + + +class SmritiMemory: + """Small LangChain-compatible memory adapter. + + The class intentionally avoids a hard LangChain dependency. It implements + the common `memory_variables`, `load_memory_variables`, `save_context`, and + `clear` methods expected by chain-style integrations. + """ + + memory_key = "smriti_context" + + def __init__(self, memory: MemPalaceLite | None = None, session_id: str = "default", topic_id: str = "general"): + self.session_id = session_id + self.topic_id = topic_id + self.memory = memory or MemPalaceLite(retrieval_mode="semantic", session_id=session_id, topic_id=topic_id) + + @property + def memory_variables(self) -> list[str]: + return [self.memory_key] + + def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]: + query = str(inputs.get("input") or inputs.get("query") or "") + return { + self.memory_key: self.memory.get_context( + query=query, + session_id=self.session_id, + topic_id=self.topic_id, + ) + } + + def save_context(self, inputs: dict[str, Any], outputs: dict[str, Any]) -> None: + text = str(inputs.get("input") or inputs.get("query") or "") + response = str(outputs.get("output") or outputs.get("response") or "") + for fact in self.memory.extract_facts(response, user_input=text): + self.memory.add_fact(fact, self.session_id, self.topic_id) + self.memory.add_to_history("User: " + text, "user_input") + if response: + self.memory.add_to_history("Assistant: " + response[:200], "assistant_output") + + def clear(self) -> None: + self.memory = MemPalaceLite(retrieval_mode=self.memory.retrieval_mode, session_id=self.session_id, topic_id=self.topic_id) diff --git a/smriti_vendor/smriti/integrations/llama_index.py b/smriti_vendor/smriti/integrations/llama_index.py new file mode 100644 index 0000000000000000000000000000000000000000..d98ddf351febcc24d6426387c3abce16a545c721 --- /dev/null +++ b/smriti_vendor/smriti/integrations/llama_index.py @@ -0,0 +1,37 @@ +"""LlamaIndex-style storage adapter for Smriti AI.""" + +from __future__ import annotations + +from typing import Any + +from smriti import MemPalaceLite + + +class SmritiStorageContext: + """Minimal node storage/retrieval context backed by Smriti AI memory.""" + + def __init__(self, memory: MemPalaceLite | None = None, session_id: str = "default", topic_id: str = "general"): + self.session_id = session_id + self.topic_id = topic_id + self.memory = memory or MemPalaceLite(retrieval_mode="semantic", session_id=session_id, topic_id=topic_id) + + def add_node(self, node: Any) -> None: + text = getattr(node, "text", None) or getattr(node, "get_content", lambda: str(node))() + self.memory.add_fact(str(text), session_id=self.session_id, topic_id=self.topic_id) + + def query(self, query: str, k: int = 5) -> list[str]: + context = self.memory.get_context( + query=query, + top_facts=k, + session_id=self.session_id, + topic_id=self.topic_id, + ) + results = [] + for line in context.splitlines(): + cleaned = line.strip() + if cleaned.startswith("* "): + results.append(cleaned[2:].strip()) + return results + + def persist(self, path: str) -> None: + self.memory.save(path) diff --git a/smriti_vendor/smriti/knowledge_graph.py b/smriti_vendor/smriti/knowledge_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f9b7d57384443951ad4f8b89f82a57da664826 --- /dev/null +++ b/smriti_vendor/smriti/knowledge_graph.py @@ -0,0 +1,366 @@ +"""Session-scoped knowledge graph memory for lightweight fact traversal.""" +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Tuple + + +@dataclass(frozen=True) +class GraphTriple: + """A subject-relation-object fact stored in graph memory.""" + + subject: str + relation: str + object: str + topic_id: str = "general" + + def to_text(self) -> str: + return f"{self.subject} {self.relation} {self.object}." + + +class KnowledgeGraphMemory: + """ + Directed graph memory keyed by session and topic. + + NetworkX is used when installed. A tiny fallback graph keeps local tests and + imports working in environments where dependencies have not been installed + yet, but packaged installs will use `networkx.DiGraph`. + """ + + def __init__(self): + self.nx = _try_import_networkx() + self.graphs: Dict[str, Dict[str, Any]] = {} + + def add_statement( + self, + session_id: str, + topic_id: str, + statement: str, + ) -> List[GraphTriple]: + """Extract simple triples from a statement and add them to graph memory.""" + + triples = extract_triples(statement) + for triple in triples: + self.add_triple( + session_id, + triple.subject, + triple.relation, + triple.object, + topic_id=topic_id, + ) + return triples + + def add_triple( + self, + session_id: str, + subject: str, + relation: str, + object_: str, + topic_id: str = "general", + ) -> GraphTriple: + """Add a subject-relation-object edge to a session/topic graph.""" + + triple = GraphTriple( + subject=_clean_entity(subject), + relation=_clean_relation(relation), + object=_clean_entity(object_), + topic_id=topic_id, + ) + if not triple.subject or not triple.relation or not triple.object: + raise ValueError("subject, relation and object_ must be non-empty.") + + graph = self._get_graph(session_id, topic_id) + subj_key = _node_key(triple.subject) + obj_key = _node_key(triple.object) + _add_node(graph, subj_key, label=triple.subject) + _add_node(graph, obj_key, label=triple.object) + + edge_data = _get_edge_data(graph, subj_key, obj_key) or {} + relations = set(edge_data.get("relations", [])) + relations.add(triple.relation) + _add_edge(graph, subj_key, obj_key, relations=sorted(relations)) + return triple + + def query_graph( + self, + session_id: str, + query_entity: str, + depth: int = 1, + topic_id: Optional[str] = None, + ) -> List[GraphTriple]: + """Traverse outward and inward from matching entities and return facts.""" + + if not query_entity.strip(): + return [] + query_candidates = _query_candidates(query_entity) + topics = self.graphs.get(session_id, {}) + selected = ( + {topic_id: topics.get(topic_id)} + if topic_id + else topics + ) + + results: List[GraphTriple] = [] + seen: set[Tuple[str, str, str, str]] = set() + for current_topic, graph in selected.items(): + if graph is None: + continue + starts = [] + for candidate in query_candidates: + starts.extend(self._matching_nodes(graph, candidate)) + starts = list(dict.fromkeys(starts)) + for start in starts: + for subj, obj, data in _traverse_edges(graph, start, depth): + subject = _node_label(graph, subj) + object_ = _node_label(graph, obj) + for relation in data.get("relations", []): + key = (current_topic, subject, relation, object_) + if key in seen: + continue + seen.add(key) + results.append(GraphTriple(subject, relation, object_, current_topic)) + return results + + def triples_to_text(self, triples: Iterable[GraphTriple]) -> List[str]: + return [triple.to_text() for triple in triples] + + def to_dict(self) -> Dict[str, Any]: + sessions: Dict[str, Dict[str, List[Dict[str, str]]]] = {} + for session_id, topics in self.graphs.items(): + sessions[session_id] = {} + for topic_id, graph in topics.items(): + triples = [] + for subj, obj, data in _all_edges(graph): + for relation in data.get("relations", []): + triples.append( + { + "subject": _node_label(graph, subj), + "relation": relation, + "object": _node_label(graph, obj), + } + ) + sessions[session_id][topic_id] = triples + return {"version": 1, "sessions": sessions} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "KnowledgeGraphMemory": + inst = cls() + for session_id, topics in data.get("sessions", {}).items(): + for topic_id, triples in topics.items(): + for triple in triples: + inst.add_triple( + session_id, + triple["subject"], + triple["relation"], + triple["object"], + topic_id=topic_id, + ) + return inst + + def _get_graph(self, session_id: str, topic_id: str) -> Any: + session = self.graphs.setdefault(session_id, {}) + if topic_id not in session: + session[topic_id] = self.nx.DiGraph() if self.nx else _SimpleDiGraph() + return session[topic_id] + + def _matching_nodes(self, graph: Any, query: str) -> List[str]: + query_key = _node_key(query) + matches = [] + for node in _nodes(graph): + label = _node_label(graph, node) + node_key = _node_key(label) + if query_key in node_key or node_key in query_key: + matches.append(node) + return matches + + +def extract_triples(text: str) -> List[GraphTriple]: + """Extract simple triples with conservative regex patterns.""" + + triples: List[GraphTriple] = [] + sentence_parts: List[str] = [] + for sentence in [s.strip() for s in re.split(r"[.!?\n]+", text) if s.strip()]: + sentence_parts.extend(_split_clauses(sentence)) + patterns = [ + (r"^(my name) is ([\w .,'-]+)$", "user", "name is", 2), + (r"^i (?:am|'m) (?:a |an )?([\w .,'-]+)$", "user", "is", 1), + (r"^i work (?:at|for) ([\w .,'-]+)$", "user", "works at", 1), + (r"^i work as (?:a |an )?([\w .,'-]+)$", "user", "works as", 1), + (r"^i (?:live|am based) (?:in|at) ([\w .,'-]+)$", "user", "lives in", 1), + (r"^i (?:study|studying|research|researching|build|building) ([\w .,'-]+)$", "user", "works on", 1), + (r"^([\w .,'-]+?) is (?:a |an |the )?([\w .,'-]+)$", 1, "is", 2), + (r"^([\w .,'-]+?) works (?:at|for) ([\w .,'-]+)$", 1, "works at", 2), + (r"^([\w .,'-]+?) lives in ([\w .,'-]+)$", 1, "lives in", 2), + (r"^([\w .,'-]+?) (?:studies|researches|builds|uses|likes) ([\w .,'-]+)$", 1, "relates to", 2), + ] + for sentence in sentence_parts: + lowered = sentence.lower().strip() + for pattern, subject_group, relation, object_group in patterns: + match = re.match(pattern, lowered, flags=re.IGNORECASE) + if not match: + continue + subject = subject_group if isinstance(subject_group, str) else match.group(subject_group) + object_ = match.group(object_group) + triples.append( + GraphTriple( + subject=_clean_entity(subject), + relation=_clean_relation(relation), + object=_clean_entity(object_), + ) + ) + break + return triples + + +class _SimpleDiGraph: + def __init__(self): + self.node_attrs: Dict[str, Dict[str, Any]] = {} + self.edges: Dict[Tuple[str, str], Dict[str, Any]] = {} + + def add_node(self, node: str, **attrs: Any) -> None: + self.node_attrs.setdefault(node, {}).update(attrs) + + def add_edge(self, source: str, target: str, **attrs: Any) -> None: + self.edges[(source, target)] = attrs + + def has_edge(self, source: str, target: str) -> bool: + return (source, target) in self.edges + + def get_edge_data(self, source: str, target: str, default: Any = None) -> Any: + return self.edges.get((source, target), default) + + def successors(self, node: str) -> List[str]: + return [target for source, target in self.edges if source == node] + + def predecessors(self, node: str) -> List[str]: + return [source for source, target in self.edges if target == node] + + @property + def nodes(self) -> Dict[str, Dict[str, Any]]: + return self.node_attrs + + +def _try_import_networkx() -> Any: + try: + import networkx as nx + + return nx + except Exception: + return None + + +def _clean_entity(value: str) -> str: + return re.sub(r"\s+", " ", value.strip(" .,'\"")).strip() + + +def _clean_relation(value: str) -> str: + return re.sub(r"\s+", " ", value.strip().lower()) + + +def _node_key(value: str) -> str: + return _clean_entity(value).lower() + + +def _split_clauses(sentence: str) -> List[str]: + clauses = [] + remaining = sentence.strip() + while True: + match = re.search(r"\s+and\s+(?=i\s)", remaining, flags=re.IGNORECASE) + if not match: + clauses.append(remaining) + break + clauses.append(remaining[: match.start()]) + remaining = remaining[match.end() :] + return [clause.strip() for clause in clauses if clause.strip()] + + +def _query_candidates(query: str) -> List[str]: + cleaned = _clean_entity(query.rstrip("?")) + lowered = cleaned.lower() + candidates = [cleaned] + question_patterns = [ + r"^who\s+\w+\s+(.+)$", + r"^what\s+did\s+(.+?)\s+\w+", + r"^what\s+is\s+(.+)$", + r"^tell\s+me\s+about\s+(.+)$", + ] + for pattern in question_patterns: + match = re.search(pattern, lowered) + if match: + candidates.append(match.group(1)) + for match in re.finditer(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b", query): + candidates.append(match.group(0)) + stop = { + "who", "what", "when", "where", "why", "how", "did", "does", "do", + "is", "are", "was", "were", "the", "a", "an", "discover", "discovered", + } + for token in re.findall(r"[a-zA-Z0-9][a-zA-Z0-9'-]{2,}", lowered): + if token not in stop: + candidates.append(token) + deduped = [] + seen = set() + for candidate in candidates: + candidate = _clean_entity(candidate) + key = candidate.lower() + if candidate and key not in seen: + seen.add(key) + deduped.append(candidate) + return deduped + + +def _add_node(graph: Any, node: str, **attrs: Any) -> None: + graph.add_node(node, **attrs) + + +def _add_edge(graph: Any, source: str, target: str, **attrs: Any) -> None: + graph.add_edge(source, target, **attrs) + + +def _get_edge_data(graph: Any, source: str, target: str) -> Optional[Dict[str, Any]]: + return graph.get_edge_data(source, target, default=None) + + +def _node_label(graph: Any, node: str) -> str: + try: + return graph.nodes[node].get("label", node) + except Exception: + return graph.nodes.get(node, {}).get("label", node) + + +def _nodes(graph: Any) -> List[str]: + try: + return list(graph.nodes) + except Exception: + return list(graph.nodes.keys()) + + +def _all_edges(graph: Any) -> Iterable[Tuple[str, str, Dict[str, Any]]]: + if hasattr(graph, "edges") and callable(getattr(graph, "edges")): + return graph.edges(data=True) + return [(source, target, data) for (source, target), data in graph.edges.items()] + + +def _edge_between(graph: Any, source: str, target: str) -> Dict[str, Any]: + return graph.get_edge_data(source, target, default={}) or {} + + +def _traverse_edges(graph: Any, start: str, depth: int) -> Iterable[Tuple[str, str, Dict[str, Any]]]: + max_depth = max(1, depth) + queue: List[Tuple[str, int]] = [(start, 0)] + visited = {start} + emitted: set[Tuple[str, str]] = set() + + while queue: + node, current_depth = queue.pop(0) + if current_depth >= max_depth: + continue + neighbours = list(graph.successors(node)) + list(graph.predecessors(node)) + for neighbour in neighbours: + for source, target in ((node, neighbour), (neighbour, node)): + if graph.has_edge(source, target) and (source, target) not in emitted: + emitted.add((source, target)) + yield source, target, _edge_between(graph, source, target) + if neighbour not in visited: + visited.add(neighbour) + queue.append((neighbour, current_depth + 1)) diff --git a/smriti_vendor/smriti/macp.py b/smriti_vendor/smriti/macp.py new file mode 100644 index 0000000000000000000000000000000000000000..969d03de06e83eadb17dd527b1a2d14c910aea42 --- /dev/null +++ b/smriti_vendor/smriti/macp.py @@ -0,0 +1,56 @@ +from dataclasses import dataclass +from typing import List + + +@dataclass +class ReasoningStep: + step_id: int + input_context: str + model_output: str + confidence: float + next_action: str + + +class MACPLite: + """ + Multi-Agent Continuity Protocol (Lite). + + Maintains a structured reasoning chain across inference turns. + Each turn records input context, model output, confidence score, + and the recommended next action — enabling auditable continuity. + """ + + def __init__(self): + self.reasoning_chain: List[ReasoningStep] = [] + self.current_step = 0 + self.context_buffer = '' + + def start_chain(self, initial_input: str): + self.context_buffer = initial_input + + def add_step(self, input_ctx: str, output: str, + confidence: float, next_action: str): + self.reasoning_chain.append(ReasoningStep( + self.current_step, + input_ctx[:200], + output[:300], + confidence, + next_action, + )) + self.current_step += 1 + self.context_buffer = output + + def get_chain_summary(self) -> str: + if not self.reasoning_chain: + return 'No reasoning chain yet.' + lines = ['[REASONING CHAIN]'] + for s in self.reasoning_chain[-5:]: + lines.append( + f' Step {s.step_id} (conf={s.confidence:.2f}): ' + f'{s.model_output[:80]}...') + return '\n'.join(lines) + + def reset(self): + self.reasoning_chain = [] + self.current_step = 0 + self.context_buffer = '' diff --git a/smriti_vendor/smriti/mem_palace.py b/smriti_vendor/smriti/mem_palace.py new file mode 100644 index 0000000000000000000000000000000000000000..7dece854859134cec2a2ca29ae935e287035d251 --- /dev/null +++ b/smriti_vendor/smriti/mem_palace.py @@ -0,0 +1,5 @@ +"""Compatibility module for older `mempalace.mem_palace` imports.""" + +from .core import MemoryEntry, MemPalaceLite + +__all__ = ["MemoryEntry", "MemPalaceLite"] diff --git a/smriti_vendor/smriti/semantic_memory.py b/smriti_vendor/smriti/semantic_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..7a43ae3c506c1f8e1ee2910a48c0519b3982c223 --- /dev/null +++ b/smriti_vendor/smriti/semantic_memory.py @@ -0,0 +1,462 @@ +"""Semantic, hierarchical memory backed by sentence embeddings and FAISS.""" +from __future__ import annotations + +import json +import math +import re +import time +from hashlib import blake2b +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np + + +DEFAULT_EMBEDDING_MODEL = "all-MiniLM-L6-v2" +DEFAULT_ARCHIVE_PATH = "smriti_memory_archive.jsonl" + + +@dataclass +class MemoryEntry: + """One fact or turn stored in a topic-level semantic index.""" + + text: str + timestamp: int + embedding: List[float] + created_at: float = field(default_factory=time.time) + last_accessed: int = 0 + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class RetrievalResult: + """Ranked semantic retrieval result with both similarity and decay scores.""" + + entry: MemoryEntry + cosine_similarity: float + decay: float + score: float + + +@dataclass +class _TopicStore: + entries: List[MemoryEntry] = field(default_factory=list) + index: Any = None + use_faiss: bool = False + + +class _HashingEmbeddingModel: + """Small deterministic fallback used when sentence-transformers is absent.""" + + def __init__(self, dimension: int = 384): + self.dimension = dimension + + def encode( + self, + texts: Sequence[str] | str, + convert_to_numpy: bool = True, + normalize_embeddings: bool = True, + **_: Any, + ) -> np.ndarray: + single = isinstance(texts, str) + values = [texts] if single else list(texts) + vectors = np.zeros((len(values), self.dimension), dtype=np.float32) + for row, text in enumerate(values): + for token in re.findall(r"[a-z0-9']+", text.lower()): + digest = blake2b(token.encode("utf-8"), digest_size=8).digest() + idx = int.from_bytes(digest, byteorder="big") % self.dimension + vectors[row, idx] += 1.0 + if normalize_embeddings: + vectors = _normalize(vectors) + if single: + return vectors[0] + return vectors + + +class _NumpyIndex: + """In-memory cosine index with a FAISS-like search interface.""" + + def __init__(self, dimension: int): + self.dimension = dimension + self.vectors = np.empty((0, dimension), dtype=np.float32) + + def add(self, vectors: np.ndarray) -> None: + vectors = np.asarray(vectors, dtype=np.float32) + if vectors.ndim == 1: + vectors = vectors.reshape(1, -1) + self.vectors = np.vstack([self.vectors, vectors]) + + def search(self, query: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]: + if self.vectors.size == 0: + return ( + np.empty((1, 0), dtype=np.float32), + np.empty((1, 0), dtype=np.int64), + ) + query = np.asarray(query, dtype=np.float32).reshape(1, -1) + sims = self.vectors @ query[0] + order = np.argsort(-sims)[:k] + return sims[order].reshape(1, -1), order.astype(np.int64).reshape(1, -1) + + +class SemanticMemory: + """ + Hierarchical semantic memory with per-topic vector indices. + + Sessions isolate users or conversations. Each session owns topics, and + each topic owns `MemoryEntry` objects indexed by cosine similarity. + """ + + def __init__( + self, + embedding_model_name: str = DEFAULT_EMBEDDING_MODEL, + decay_rate: float = 0.05, + max_entries_per_topic: int = 200, + archive_path: str | Path = DEFAULT_ARCHIVE_PATH, + device: str = "auto", + use_gpu_index: Optional[bool] = None, + embedding_model: Any = None, + faiss_module: Any = None, + ): + self.embedding_model_name = embedding_model_name + self.decay_rate = decay_rate + self.max_entries_per_topic = max_entries_per_topic + self.archive_path = Path(archive_path) + self.step_counter = 0 + self.sessions: Dict[str, Dict[str, _TopicStore]] = {} + + self.device = _resolve_device(device) + self.embedding_model = embedding_model or self._load_embedding_model() + if faiss_module is False: + self.faiss = None + else: + self.faiss = faiss_module if faiss_module is not None else _try_import_faiss() + self.use_gpu_index = _cuda_available() if use_gpu_index is None else use_gpu_index + + def _load_embedding_model(self) -> Any: + try: + from sentence_transformers import SentenceTransformer + + return SentenceTransformer(self.embedding_model_name, device=self.device) + except Exception: + return _HashingEmbeddingModel() + + def _embed(self, text: str) -> np.ndarray: + vector = self.embedding_model.encode( + text, + convert_to_numpy=True, + normalize_embeddings=True, + ) + vector = np.asarray(vector, dtype=np.float32) + if vector.ndim > 1: + vector = vector[0] + return _normalize(vector.reshape(1, -1))[0] + + def _get_topic(self, session_id: str, topic_id: str) -> _TopicStore: + session = self.sessions.setdefault(session_id, {}) + return session.setdefault(topic_id, _TopicStore()) + + def _new_index(self, dimension: int) -> Tuple[Any, bool]: + if self.faiss is None: + return _NumpyIndex(dimension), False + + index = None + try: + index = self.faiss.IndexHNSWFlat( + dimension, 32, self.faiss.METRIC_INNER_PRODUCT + ) + index.hnsw.efSearch = 64 + index.hnsw.efConstruction = 80 + except Exception: + index = self.faiss.IndexFlatIP(dimension) + + if self.use_gpu_index and hasattr(self.faiss, "StandardGpuResources"): + try: + resources = self.faiss.StandardGpuResources() + index = self.faiss.index_cpu_to_gpu(resources, 0, index) + except Exception: + pass + return index, True + + def _rebuild_topic_index(self, topic: _TopicStore) -> None: + if not topic.entries: + topic.index = None + topic.use_faiss = False + return + vectors = np.asarray([e.embedding for e in topic.entries], dtype=np.float32) + vectors = _normalize(vectors) + topic.index, topic.use_faiss = self._new_index(vectors.shape[1]) + topic.index.add(vectors) + + def add_entry( + self, + session_id: str, + topic_id: str, + text: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> MemoryEntry: + """Embed text and add it to the session/topic memory.""" + + cleaned = text.strip() + if not cleaned: + raise ValueError("Cannot add an empty memory entry.") + + self.step_counter += 1 + vector = self._embed(cleaned) + entry = MemoryEntry( + text=cleaned, + timestamp=self.step_counter, + embedding=vector.tolist(), + last_accessed=self.step_counter, + metadata=metadata or {}, + ) + topic = self._get_topic(session_id, topic_id) + topic.entries.append(entry) + + if topic.index is None: + self._rebuild_topic_index(topic) + else: + topic.index.add(vector.reshape(1, -1).astype(np.float32)) + + self._enforce_topic_limit(session_id, topic_id) + return entry + + def retrieve( + self, + session_id: str, + topic_id: str, + query: str, + k: int = 5, + ) -> List[RetrievalResult]: + """Return top-k entries ranked by cosine similarity and age decay.""" + + topic = self.sessions.get(session_id, {}).get(topic_id) + if not topic or not topic.entries or not query.strip(): + return [] + if topic.index is None: + self._rebuild_topic_index(topic) + + candidate_k = min(len(topic.entries), max(k * 5, k)) + query_vector = self._embed(query).reshape(1, -1).astype(np.float32) + distances, indices = topic.index.search(query_vector, candidate_k) + + ranked: List[RetrievalResult] = [] + for cosine, idx in zip(distances[0], indices[0]): + if idx < 0 or idx >= len(topic.entries): + continue + entry = topic.entries[int(idx)] + age = max(0, self.step_counter - entry.timestamp) + decay = math.exp(-self.decay_rate * age) + score = max(0.0, float(cosine)) * decay + entry.last_accessed = self.step_counter + ranked.append(RetrievalResult(entry, float(cosine), decay, score)) + + ranked.sort(key=lambda item: item.score, reverse=True) + return ranked[:k] + + def recent_entries( + self, + session_id: str, + topic_id: str, + k: int = 5, + ) -> List[MemoryEntry]: + """Return the most recent entries for a session/topic.""" + + topic = self.sessions.get(session_id, {}).get(topic_id) + if not topic: + return [] + return sorted(topic.entries, key=lambda entry: entry.timestamp, reverse=True)[:k] + + def compress_topic(self, session_id: str, topic_id: str) -> Optional[MemoryEntry]: + """ + Summarise older entries, archive originals, and rebuild the topic index. + + The summariser is extractive and intentionally dependency-light. If NLTK + sentence tokenisation is available it is used; otherwise a regex splitter + keeps the method fully local. + """ + + topic = self.sessions.get(session_id, {}).get(topic_id) + if not topic or len(topic.entries) < 4: + return None + + keep_count = max(2, min(5, len(topic.entries) // 3)) + ordered = sorted(topic.entries, key=lambda entry: entry.timestamp) + older = ordered[:-keep_count] + recent = ordered[-keep_count:] + if not older: + return None + + summary_text = _extractive_summary([entry.text for entry in older]) + self._archive_entries(session_id, topic_id, older) + + self.step_counter += 1 + vector = self._embed(summary_text) + summary = MemoryEntry( + text=summary_text, + timestamp=self.step_counter, + embedding=vector.tolist(), + last_accessed=self.step_counter, + metadata={ + "compressed": True, + "source_count": len(older), + "source_timestamp_min": min(entry.timestamp for entry in older), + "source_timestamp_max": max(entry.timestamp for entry in older), + }, + ) + topic.entries = [summary] + recent + self._rebuild_topic_index(topic) + return summary + + def to_dict(self) -> Dict[str, Any]: + return { + "version": 1, + "embedding_model_name": self.embedding_model_name, + "decay_rate": self.decay_rate, + "max_entries_per_topic": self.max_entries_per_topic, + "archive_path": str(self.archive_path), + "step_counter": self.step_counter, + "sessions": { + session_id: { + topic_id: [asdict(entry) for entry in topic.entries] + for topic_id, topic in topics.items() + } + for session_id, topics in self.sessions.items() + }, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> "SemanticMemory": + init_kwargs = { + "embedding_model_name": data.get( + "embedding_model_name", DEFAULT_EMBEDDING_MODEL + ), + "decay_rate": data.get("decay_rate", 0.05), + "max_entries_per_topic": data.get("max_entries_per_topic", 200), + "archive_path": data.get("archive_path", DEFAULT_ARCHIVE_PATH), + } + init_kwargs.update(kwargs) + inst = cls(**init_kwargs) + inst.step_counter = data.get("step_counter", 0) + for session_id, topics in data.get("sessions", {}).items(): + for topic_id, entries in topics.items(): + topic = inst._get_topic(session_id, topic_id) + topic.entries = [MemoryEntry(**entry) for entry in entries] + inst._rebuild_topic_index(topic) + return inst + + def save(self, path: str | Path = "smriti_memory.json") -> None: + with open(path, "w", encoding="utf-8") as handle: + json.dump(self.to_dict(), handle, indent=2) + + @classmethod + def load(cls, path: str | Path = "smriti_memory.json", **kwargs: Any) -> "SemanticMemory": + with open(path, encoding="utf-8") as handle: + data = json.load(handle) + return cls.from_dict(data, **kwargs) + + def _enforce_topic_limit(self, session_id: str, topic_id: str) -> None: + topic = self.sessions.get(session_id, {}).get(topic_id) + if not topic or len(topic.entries) <= self.max_entries_per_topic: + return + + self.compress_topic(session_id, topic_id) + topic = self.sessions[session_id][topic_id] + if len(topic.entries) <= self.max_entries_per_topic: + return + + topic.entries.sort(key=lambda entry: (entry.last_accessed, entry.timestamp)) + evicted = topic.entries[: len(topic.entries) - self.max_entries_per_topic] + topic.entries = topic.entries[len(topic.entries) - self.max_entries_per_topic:] + self._archive_entries(session_id, topic_id, evicted, reason="evicted_lru") + self._rebuild_topic_index(topic) + + def _archive_entries( + self, + session_id: str, + topic_id: str, + entries: Iterable[MemoryEntry], + reason: str = "compressed", + ) -> None: + entries = list(entries) + if not entries: + return + self.archive_path.parent.mkdir(parents=True, exist_ok=True) + with open(self.archive_path, "a", encoding="utf-8") as handle: + for entry in entries: + payload = { + "session_id": session_id, + "topic_id": topic_id, + "reason": reason, + "archived_at": time.time(), + "entry": asdict(entry), + } + handle.write(json.dumps(payload) + "\n") + + +def _normalize(vectors: np.ndarray) -> np.ndarray: + vectors = np.asarray(vectors, dtype=np.float32) + norms = np.linalg.norm(vectors, axis=-1, keepdims=True) + norms = np.where(norms == 0.0, 1.0, norms) + return vectors / norms + + +def _try_import_faiss() -> Any: + try: + import faiss + + return faiss + except Exception: + return None + + +def _cuda_available() -> bool: + try: + import torch + + return bool(torch.cuda.is_available()) + except Exception: + return False + + +def _resolve_device(device: str) -> str: + if device != "auto": + return device + return "cuda" if _cuda_available() else "cpu" + + +def _split_sentences(text: str) -> List[str]: + try: + import nltk + + return [s.strip() for s in nltk.sent_tokenize(text) if s.strip()] + except Exception: + return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()] + + +def _extractive_summary(texts: Sequence[str], max_sentences: int = 3) -> str: + text = " ".join(t.strip() for t in texts if t.strip()) + sentences = _split_sentences(text) + if not sentences: + return text[:500] + if len(sentences) <= max_sentences: + return " ".join(sentences) + + tokens = re.findall(r"[a-z0-9']+", text.lower()) + stopwords = { + "the", "and", "a", "an", "to", "of", "in", "is", "it", "for", + "on", "with", "that", "this", "i", "you", "we", "they", "am", + } + freqs: Dict[str, int] = {} + for token in tokens: + if token not in stopwords: + freqs[token] = freqs.get(token, 0) + 1 + + scored = [] + for idx, sentence in enumerate(sentences): + words = re.findall(r"[a-z0-9']+", sentence.lower()) + score = sum(freqs.get(word, 0) for word in words) / max(1, len(words)) + scored.append((score, idx, sentence)) + selected = sorted(scored, key=lambda item: (-item[0], item[1]))[:max_sentences] + selected.sort(key=lambda item: item[1]) + return " ".join(sentence for _, _, sentence in selected)[:750] diff --git a/test_handler_local.py b/test_handler_local.py new file mode 100644 index 0000000000000000000000000000000000000000..88c3b5eb296ca4bc7abdf265f46e73a54743cf44 --- /dev/null +++ b/test_handler_local.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +"""Local smoke test for the Smriti AI Hugging Face custom handler.""" + +from __future__ import annotations + +import json +import os +from pathlib import Path + +from handler import EndpointHandler + + +def pretty(title: str, payload: dict) -> None: + print(f"\n=== {title} ===") + print(json.dumps(payload, indent=2, ensure_ascii=False)) + + +def main() -> int: + os.environ.setdefault("BASE_MODEL_ID", "sshleifer/tiny-gpt2") + os.environ.setdefault("SMRITI_MEMORY_BACKEND", "json") + os.environ.setdefault("SMRITI_MEMORY_PATH", "/tmp/smriti_hf_test.json") + os.environ.setdefault("SMRITI_RETRIEVAL_MODE", "semantic_graph_identity") + retrieval_mode = os.environ["SMRITI_RETRIEVAL_MODE"] + + handler = EndpointHandler(path=str(Path(__file__).resolve().parent)) + + pretty("health", handler({"inputs": {"operation": "health"}})) + + pretty( + "fact injection", + handler( + { + "inputs": { + "operation": "chat", + "user_id": "local-demo-user", + "message": "My name is Alex and I am a marine biologist based in Hawaii.", + "retrieval_mode": retrieval_mode, + }, + "parameters": {"max_new_tokens": 64, "return_memories": True}, + } + ), + ) + + pretty( + "distractor", + handler( + { + "inputs": { + "operation": "chat", + "user_id": "local-demo-user", + "message": "What is the capital of France?", + "retrieval_mode": retrieval_mode, + }, + "parameters": {"max_new_tokens": 64, "return_memories": True}, + } + ), + ) + + pretty( + "recall", + handler( + { + "inputs": { + "operation": "chat", + "user_id": "local-demo-user", + "message": "What do you remember about me?", + "retrieval_mode": retrieval_mode, + }, + "parameters": {"max_new_tokens": 64, "return_memories": True}, + } + ), + ) + + pretty( + "delete memory", + handler({"inputs": {"operation": "delete_memory", "user_id": "local-demo-user"}}), + ) + + pretty( + "recall after delete", + handler( + { + "inputs": { + "operation": "chat", + "user_id": "local-demo-user", + "message": "What do you remember about me?", + "retrieval_mode": retrieval_mode, + }, + "parameters": {"max_new_tokens": 64, "return_memories": True}, + } + ), + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/upload_model_repo.py b/upload_model_repo.py new file mode 100644 index 0000000000000000000000000000000000000000..54aa07852b27f9dae361bf1576f1c59fce1f2afa --- /dev/null +++ b/upload_model_repo.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +"""Upload the Smriti AI Hugging Face model-style deployment directory.""" + +from __future__ import annotations + +import argparse +import os +import shutil +import tempfile +from pathlib import Path + +from huggingface_hub import HfApi + +ROOT = Path(__file__).resolve().parents[2] +DEPLOY_DIR = Path(__file__).resolve().parent + +EXCLUDED_NAMES = { + "__pycache__", + ".DS_Store", + ".env", + ".env.local", + "smriti_memory.json", + "godelai_memory.json", +} +EXCLUDED_SUFFIXES = {".pyc", ".pyo", ".sqlite", ".sqlite3", ".db"} + + +def main() -> int: + parser = argparse.ArgumentParser(description="Upload Smriti AI as a Hugging Face model repo.") + parser.add_argument("--repo-id", required=True, help="Target model repo, e.g. YOUR_ORG/smriti-ai.") + parser.add_argument("--private", default="true", help="true or false. Defaults to true.") + parser.add_argument( + "--include-package-source", + action="store_true", + help="Vendor src/smriti into the model repo as smriti_vendor/smriti.", + ) + args = parser.parse_args() + + token = os.getenv("HF_TOKEN") + if not token: + raise SystemExit("HF_TOKEN is required in the environment. It will not be printed.") + + private = str(args.private).lower() in {"1", "true", "yes", "on"} + api = HfApi(token=token) + api.create_repo(repo_id=args.repo_id, repo_type="model", private=private, exist_ok=True) + + with tempfile.TemporaryDirectory() as tmp: + stage = Path(tmp) / "hf_model_repo" + _copy_tree(DEPLOY_DIR, stage) + if args.include_package_source: + vendor_root = stage / "smriti_vendor" + shutil.copytree(ROOT / "src" / "smriti", vendor_root / "smriti") + mempalace = ROOT / "src" / "mempalace" + if mempalace.exists(): + shutil.copytree(mempalace, vendor_root / "mempalace") + api.upload_folder( + repo_id=args.repo_id, + repo_type="model", + folder_path=str(stage), + commit_message="Deploy Smriti AI Hugging Face handler", + ) + + print(f"Uploaded Smriti AI model-style repo: https://huggingface.co/{args.repo_id}") + return 0 + + +def _copy_tree(src: Path, dst: Path) -> None: + dst.mkdir(parents=True, exist_ok=True) + for item in src.iterdir(): + if _excluded(item): + continue + target = dst / item.name + if item.is_dir(): + _copy_tree(item, target) + else: + shutil.copy2(item, target) + + +def _excluded(path: Path) -> bool: + if path.name in EXCLUDED_NAMES or path.suffix in EXCLUDED_SUFFIXES: + return True + parts = set(path.parts) + if "data" in parts or ".git" in parts or ".cache" in parts: + return True + if path.name.startswith("smriti_memory"): + return True + return False + + +if __name__ == "__main__": + raise SystemExit(main())