Text Generation
English
smriti-memory-ai
smriti-ai
memory
agent-memory
long-term-memory
external-memory
training-free
frozen-model
inference-time-augmentation
retrieval-augmented-generation
rag
semantic-search
knowledge-graph
identity-continuity
small-language-model
small-language-models
ai-agent
gemma
gemma-4
qwen
qwen2.5
llama
llama-3.2
phi-3
| """Hugging Face custom inference handler for Smriti AI. | |
| This file is intentionally deployment glue. Core memory, retrieval, graph, and | |
| identity behavior comes from the installed `smriti` package. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import sys | |
| import time | |
| import urllib.error | |
| import urllib.request | |
| from pathlib import Path | |
| from threading import RLock | |
| from typing import Any, Dict, List, Optional, Tuple | |
# Prefer a vendored copy of the ``smriti`` package shipped next to this handler
# (when one exists) by putting its directory first on sys.path.
VENDOR_SRC = Path(__file__).resolve().parent / "smriti_vendor"
if str(VENDOR_SRC) not in sys.path and VENDOR_SRC.exists():
    sys.path.insert(0, str(VENDOR_SRC))
| from smriti import IdentityFingerprint, MemPalaceLite, SmritiAILite # noqa: E402 | |
| from smriti.backends import ( # noqa: E402 | |
| JsonBackend, | |
| MemoryBackend, | |
| MemoryCipher, | |
| PostgresBackend, | |
| RedisBackend, | |
| SqliteBackend, | |
| ) | |
| from smriti.production_safety import ( # noqa: E402 | |
| GEMMA4_MODEL_ID, | |
| is_production_mode, | |
| validate_model_id_for_environment, | |
| ) | |
LOGGER = logging.getLogger("smriti.hf_handler")
if not LOGGER.handlers:
    # logging only accepts upper-case level names ("INFO", not "info") or ints, so
    # normalise the env value instead of failing at import time on e.g. "debug".
    _raw_log_level = os.getenv("SMRITI_LOG_LEVEL", "INFO").strip()
    logging.basicConfig(
        level=int(_raw_log_level) if _raw_log_level.isdigit() else _raw_log_level.upper()
    )
# Baseline handler settings. Keys present in the repository's config.json are
# merged over these, and several are overridden again by environment variables
# when the handler is constructed.
DEFAULT_CONFIG = dict(
    project="Smriti AI",
    base_model=GEMMA4_MODEL_ID,
    retrieval_mode="semantic_graph_identity",
    memory_backend="json",
    public_demo=False,
    max_memory_entries=1000,
    enable_identity=True,
    enable_graph=True,
    enable_encryption=True,
)
class EndpointHandler:
    """Hugging Face custom inference endpoint handler.

    Keeps one ``MemPalaceLite`` memory and one ``IdentityFingerprint`` per
    ``user_id`` and wraps them around a base model. The model is loaded locally
    (``BASE_MODEL_ID``), called remotely (``HF_ENDPOINT_URL``), or absent, in which
    case answers are built from memory only. Supported request operations are
    ``chat`` (the default), ``health`` and ``delete_memory``.
    """

    # Identity drift threshold used when identity checking is enabled for a request.
    _IDENTITY_THRESHOLD = 0.35
    # Threshold applied when identity checking is disabled; presumably large enough
    # that drift never triggers refinement. TODO confirm against IdentityFingerprint.
    _IDENTITY_DISABLED_THRESHOLD = 2.0
    # Maximum number of prompt tokens fed to a locally loaded model.
    _MAX_PROMPT_TOKENS = 2048

    def __init__(self, path: str = ""):
        """Load config, choose the memory backend and prepare the base model.

        Args:
            path: Directory passed by the endpoint runtime. A file path resolves to
                its parent directory; an empty string means this file's directory.
        """
        self.root = _resolve_root(path)
        self.config = _load_config(self.root / "config.json")
        self.lock = RLock()
        self.memories: Dict[str, MemPalaceLite] = {}
        self.identities: Dict[str, IdentityFingerprint] = {}
        self.backend_warning: Optional[str] = None
        self.endpoint_url = os.getenv("HF_ENDPOINT_URL", "").strip()
        base_model_env = os.getenv("BASE_MODEL_ID")
        base_model_raw = base_model_env if base_model_env is not None else self.config.get("base_model", "")
        self.base_model_id = _clean_model_id(
            base_model_raw,
            allow_empty=bool(self.endpoint_url) or (base_model_env is not None and not is_production_mode()),
        )
        self.hf_token = os.getenv("HF_TOKEN", "").strip()
        self.default_retrieval_mode = os.getenv(
            "SMRITI_RETRIEVAL_MODE",
            str(self.config.get("retrieval_mode", "semantic_graph_identity")),
        )
        self.max_memory_entries = _int_env(
            "SMRITI_MAX_MEMORY_ENTRIES",
            int(self.config.get("max_memory_entries", 1000)),
        )
        self.public_demo = _bool_env("SMRITI_PUBLIC_DEMO", bool(self.config.get("public_demo", False)))
        self.enable_graph_default = bool(self.config.get("enable_graph", True))
        self.enable_identity_default = bool(self.config.get("enable_identity", True))
        self.enable_encryption = bool(self.config.get("enable_encryption", True))
        self.backend, self.backend_name = self._init_backend()
        self.model = None
        self.tokenizer = None
        self.device = "cpu"
        if self.endpoint_url:
            LOGGER.info(
                "Smriti AI handler using remote model endpoint; backend=%s retrieval=%s",
                self.backend_name,
                self.default_retrieval_mode,
            )
        elif self.base_model_id:
            self._load_local_model(self.base_model_id)
        else:
            LOGGER.warning(
                "No BASE_MODEL_ID or HF_ENDPOINT_URL configured; handler will run memory-only."
            )
        LOGGER.info(
            "Smriti AI handler ready: base_model=%s remote_endpoint=%s backend=%s retrieval=%s encryption=%s public_demo=%s",
            self.base_model_id or "memory-only",
            bool(self.endpoint_url),
            self.backend_name,
            self.default_retrieval_mode,
            self.enable_encryption and bool(os.getenv("SMRITI_ENCRYPTION_KEY") or os.getenv("SMRITI_MEMORY_KEY")),
            self.public_demo,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch one endpoint request by ``inputs["operation"]`` (default ``chat``).

        Any unexpected exception is logged and turned into an ``error`` payload, so
        the endpoint runtime receives a dict rather than a raised exception.
        """
        start = time.perf_counter()
        try:
            inputs, parameters = _normalize_request(data)
            operation = str(inputs.get("operation", "chat")).lower()
            if operation == "health":
                return self._health(start)
            if operation == "delete_memory":
                return self._delete_memory(inputs, start)
            if operation != "chat":
                return _error(f"Unsupported operation: {operation}", start)
            return self._chat(inputs, parameters, start)
        except Exception as exc:  # Defensive boundary for endpoint runtimes.
            LOGGER.exception("Unhandled Smriti AI handler error")
            return _error(f"handler_error:{exc.__class__.__name__}: {exc}", start)

    # ------------------------------------------------------------------
    # Operation handlers
    # ------------------------------------------------------------------
    def _chat(
        self,
        inputs: Dict[str, Any],
        parameters: Dict[str, Any],
        start: float,
    ) -> Dict[str, Any]:
        """Run one memory-augmented chat turn and persist the updated memory.

        Requires ``user_id`` and ``message``. ``topic_id`` defaults to ``general``
        and ``retrieval_mode`` defaults to the handler default. Graph facts and
        identity checks are enabled when the mode string contains "graph" /
        "identity" (case-insensitive) and the matching config flag is on.
        """
        user_id = str(inputs.get("user_id") or "").strip()
        message = str(inputs.get("message") or "").strip()
        topic_id = str(inputs.get("topic_id") or "general").strip() or "general"
        if not user_id:
            return _error("user_id is required", start)
        if not message:
            return _error("message is required for chat operation", start)
        retrieval_mode = str(inputs.get("retrieval_mode") or self.default_retrieval_mode)
        # Match feature flags case-insensitively, consistent with _base_retrieval_mode().
        mode_key = retrieval_mode.lower()
        base_retrieval = _base_retrieval_mode(retrieval_mode)
        include_graph = self.enable_graph_default and "graph" in mode_key
        identity_enabled = self.enable_identity_default and "identity" in mode_key
        with self.lock:
            memory = self._get_memory(user_id, topic_id, base_retrieval)
            _context, retrieved_memories, graph_facts, retrieval_warning = self._retrieve_context(
                memory,
                user_id,
                topic_id,
                message,
                include_graph,
            )
            identity = self._get_identity(user_id, identity_enabled)
            agent = SmritiAILite(
                model=self.model,
                tokenizer=self.tokenizer,
                retrieval_mode=base_retrieval,
                session_id=user_id,
                topic_id=topic_id,
                memory=memory,
                identity=identity,
                auto_device=False,
            )
            # Replace the agent's prompt builder so the graph/identity switches from
            # this request are honoured.
            agent.build_prompt = lambda user_input: _build_prompt(
                agent,
                memory,
                user_id,
                topic_id,
                user_input,
                include_graph,
                identity_enabled,
            )
            generation_calls = 0

            def generate(prompt: str, max_tokens: int = 256) -> str:
                # Counts calls so a second generation (identity refinement) can be reported.
                nonlocal generation_calls
                generation_calls += 1
                return self._generate_text(prompt, parameters, max_tokens=max_tokens)

            agent._generate = generate  # type: ignore[method-assign]
            try:
                response = agent.chat(message)
            except Exception as exc:
                LOGGER.exception("Model generation failed")
                return _error(f"model_generation_failed:{exc.__class__.__name__}: {exc}", start)
            response = _stabilize_recall_answer(message, response, retrieved_memories, graph_facts)
            _replace_last_assistant_history(memory, response)
            identity_check = agent.identity.evaluate_output(response) if identity_enabled else None
            save_warning = self._save_memory(user_id, memory)
            warnings = [item for item in [self.backend_warning, retrieval_warning, save_warning] if item]
            return {
                "response": response,
                "retrieved_memories": retrieved_memories,
                "graph_facts": graph_facts,
                "identity": {
                    "enabled": identity_enabled,
                    "drift_score": float(identity_check.distance) if identity_check else 0.0,
                    "refinement_triggered": generation_calls > 1,
                },
                "latency_ms": round((time.perf_counter() - start) * 1000, 3),
                "backend": self.backend_name,
                "retrieval_mode": retrieval_mode,
                "warnings": warnings,
            }

    def _delete_memory(self, inputs: Dict[str, Any], start: float) -> Dict[str, Any]:
        """Drop a user's cached memory/identity and delete their persisted state.

        ``deleted`` is true when either the in-process cache held the user or
        ``backend.delete_user`` returned a truthy value.
        """
        user_id = str(inputs.get("user_id") or "").strip()
        if not user_id:
            return _error("user_id is required for delete_memory operation", start)
        with self.lock:
            existed_cache = self.memories.pop(user_id, None) is not None
            self.identities.pop(user_id, None)
            try:
                deleted_backend = self.backend.delete_user(user_id)
            except Exception as exc:
                LOGGER.exception("Memory backend delete failed")
                return _error(f"backend_delete_failed:{exc.__class__.__name__}: {exc}", start)
        return {
            "deleted": bool(existed_cache or deleted_backend),
            "user_id": user_id,
            "latency_ms": round((time.perf_counter() - start) * 1000, 3),
            "backend": self.backend_name,
        }

    def _health(self, start: float) -> Dict[str, Any]:
        """Return a static status payload describing the configured model and backend."""
        return {
            "status": "ok",
            "project": "Smriti AI",
            "base_model": self.base_model_id or ("remote-endpoint" if self.endpoint_url else "memory-only"),
            "backend": self.backend_name,
            "retrieval_mode": self.default_retrieval_mode,
            "latency_ms": round((time.perf_counter() - start) * 1000, 3),
        }

    # ------------------------------------------------------------------
    # Runtime setup
    # ------------------------------------------------------------------
    def _init_backend(self) -> Tuple[MemoryBackend, str]:
        """Construct the memory backend and return it together with its name.

        Precedence: a Redis URL env var, then a Postgres DSN env var, then
        ``SMRITI_MEMORY_BACKEND`` / config ``memory_backend``. JSON files under
        ``SMRITI_MEMORY_PATH`` are the fallback. Side effect: a key found in
        ``SMRITI_ENCRYPTION_KEY`` is copied into ``SMRITI_MEMORY_KEY``, presumably
        so the smriti package can read it (TODO confirm).
        """
        encryption_key = os.getenv("SMRITI_ENCRYPTION_KEY") or os.getenv("SMRITI_MEMORY_KEY")
        if encryption_key:
            os.environ["SMRITI_MEMORY_KEY"] = encryption_key
        cipher = MemoryCipher(encryption_key if self.enable_encryption else None)
        redis_url = os.getenv("REDIS_URL") or os.getenv("SMRITI_REDIS_URL")
        postgres_dsn = os.getenv("POSTGRES_DSN") or os.getenv("SMRITI_POSTGRES_DSN")
        selected = (os.getenv("SMRITI_MEMORY_BACKEND") or self.config.get("memory_backend") or "json").lower()
        memory_path = os.getenv("SMRITI_MEMORY_PATH", "/tmp/smriti_hf_memory")
        if redis_url:
            return RedisBackend(url=redis_url, cipher=cipher), "redis"
        if postgres_dsn:
            return PostgresBackend(dsn=postgres_dsn, cipher=cipher), "postgres"
        if selected == "redis":
            return RedisBackend(url=redis_url or "redis://localhost:6379/0", cipher=cipher), "redis"
        if selected in {"postgres", "postgresql"}:
            return PostgresBackend(dsn=postgres_dsn or "", cipher=cipher), "postgres"
        if selected == "sqlite":
            return SqliteBackend(path=memory_path, cipher=cipher), "sqlite"
        return JsonBackend(root=_json_root(memory_path), cipher=cipher), "json"

    def _load_local_model(self, model_id: str) -> None:
        """Load tokenizer and causal-LM weights for *model_id* onto CUDA when available.

        Raises:
            RuntimeError: if torch/transformers cannot be imported.
        """
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except Exception as exc:
            raise RuntimeError("Install torch and transformers to load a local base model.") from exc
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float32
        if self.device == "cuda":
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        kwargs = {"token": self.hf_token} if self.hf_token else {}
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
        if getattr(self.tokenizer, "pad_token_id", None) is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        try:
            self.model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, **kwargs)
        except TypeError:
            # Older transformers releases only accept the ``torch_dtype`` spelling.
            self.model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **kwargs)
        self.model.to(self.device)
        self.model.eval()
        LOGGER.info("Loaded local base model %s on %s", model_id, self.device)

    # ------------------------------------------------------------------
    # Memory and generation helpers
    # ------------------------------------------------------------------
    def _get_memory(self, user_id: str, topic_id: str, retrieval_mode: str) -> MemPalaceLite:
        """Return the cached memory for *user_id*, loading or creating it on first use.

        Resets ``self.backend_warning`` and sets it if the backend load fails (an
        empty memory is used in that case). The memory is rebuilt when its
        retrieval mode differs from *retrieval_mode*.
        """
        self.backend_warning = None
        if user_id not in self.memories:
            state = None
            try:
                state = self.backend.load(user_id)
            except Exception as exc:
                LOGGER.exception("Memory backend load failed; starting empty memory")
                self.backend_warning = f"backend_load_failed:{exc.__class__.__name__}"
            if state:
                memory = MemPalaceLite.from_dict(
                    state,
                    retrieval_mode=retrieval_mode,
                    session_id=user_id,
                    topic_id=topic_id,
                    max_facts=self.max_memory_entries,
                    max_entries_per_topic=self.max_memory_entries,
                )
            else:
                memory = MemPalaceLite(
                    retrieval_mode=retrieval_mode,
                    session_id=user_id,
                    topic_id=topic_id,
                    max_facts=self.max_memory_entries,
                    max_entries_per_topic=self.max_memory_entries,
                )
            self.memories[user_id] = memory
        memory = self.memories[user_id]
        if memory.retrieval_mode != retrieval_mode:
            memory = MemPalaceLite.from_dict(
                memory.to_dict(),
                retrieval_mode=retrieval_mode,
                session_id=user_id,
                topic_id=topic_id,
                max_facts=self.max_memory_entries,
                max_entries_per_topic=self.max_memory_entries,
            )
            self.memories[user_id] = memory
        memory.session_id = user_id
        memory.topic_id = topic_id
        return memory

    def _get_identity(self, user_id: str, enabled: bool) -> IdentityFingerprint:
        """Return the cached identity fingerprint for *user_id*, creating it if needed.

        The drift threshold is re-applied on every call, so toggling identity
        checking per request works in both directions. Previously a single
        disabled request pinned the threshold at the "disabled" value forever.
        """
        threshold = self._IDENTITY_THRESHOLD if enabled else self._IDENTITY_DISABLED_THRESHOLD
        if user_id not in self.identities:
            self.identities[user_id] = IdentityFingerprint(
                role="helpful AI assistant with persistent memory",
                threshold=threshold,
            )
        identity = self.identities[user_id]
        identity.threshold = threshold
        return identity

    def _retrieve_context(
        self,
        memory: MemPalaceLite,
        user_id: str,
        topic_id: str,
        message: str,
        include_graph: bool,
    ) -> Tuple[str, List[str], List[str], Optional[str]]:
        """Fetch prompt context, top-5 facts and graph-fact bullets for *message*.

        Returns ``(context, retrieved_memories, graph_facts, warning)``. On failure
        the result is empty values plus a ``retrieval_failed:<Exception>`` warning.
        """
        try:
            context = memory.get_context(
                query=message,
                session_id=user_id,
                topic_id=topic_id,
                include_graph=include_graph,
            )
            retrieved_memories = memory.retrieve_facts(
                message,
                k=5,
                session_id=user_id,
                topic_id=topic_id,
            )
            graph_facts = _section_bullets(context, "[RELATED GRAPH FACTS]") if include_graph else []
            return context, retrieved_memories, graph_facts, None
        except Exception as exc:
            LOGGER.exception("Memory retrieval failed")
            return "", [], [], f"retrieval_failed:{exc.__class__.__name__}"

    def _save_memory(self, user_id: str, memory: MemPalaceLite) -> Optional[str]:
        """Persist *memory*; return a ``backend_save_failed:...`` warning on failure, else None."""
        try:
            self.backend.save(user_id, memory.to_dict())
            return None
        except Exception as exc:
            LOGGER.exception("Memory backend save failed")
            return f"backend_save_failed:{exc.__class__.__name__}"

    @staticmethod
    def _sampling_params(parameters: Dict[str, Any], max_tokens: int) -> Tuple[int, float, float]:
        """Parse ``(max_new_tokens, temperature, top_p)`` from request parameters.

        Missing or explicit ``null`` values fall back to *max_tokens*, 0.7 and 0.9.
        """
        max_new_tokens = int(parameters.get("max_new_tokens") or max_tokens)
        raw_temperature = parameters.get("temperature")
        temperature = 0.7 if raw_temperature is None else float(raw_temperature)
        raw_top_p = parameters.get("top_p")
        top_p = 0.9 if raw_top_p is None else float(raw_top_p)
        return max_new_tokens, temperature, top_p

    def _generate_text(self, prompt: str, parameters: Dict[str, Any], max_tokens: int = 256) -> str:
        """Generate a reply via the remote endpoint, the local model, or memory only."""
        max_new_tokens, temperature, top_p = self._sampling_params(parameters, max_tokens)
        if self.endpoint_url:
            return self._generate_remote(prompt, max_new_tokens, temperature, top_p)
        if self.model is not None and self.tokenizer is not None:
            return self._generate_local(prompt, max_new_tokens, temperature, top_p)
        return _memory_only_answer(prompt)

    def _generate_local(
        self,
        prompt: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
    ) -> str:
        """Generate with the locally loaded model and return only the new tokens' text.

        ``temperature <= 0`` selects greedy decoding. Prompts longer than
        ``_MAX_PROMPT_TOKENS`` keep their *last* tokens: the user's message and the
        chat template's generation prompt sit at the end, so cutting the tail
        would drop exactly what must be answered.
        """
        import torch

        messages = [{"role": "user", "content": prompt}]
        try:
            formatted = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            # Tokenizers without a chat template receive the raw prompt.
            formatted = prompt
        encoded = self.tokenizer(formatted, return_tensors="pt")
        if encoded["input_ids"].shape[1] > self._MAX_PROMPT_TOKENS:
            encoded = {key: value[:, -self._MAX_PROMPT_TOKENS :] for key, value in encoded.items()}
        inputs = {key: value.to(self.device) for key, value in encoded.items()}
        generate_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "pad_token_id": getattr(self.tokenizer, "eos_token_id", None),
        }
        if temperature > 0:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p
        with torch.inference_mode():
            output = self.model.generate(**inputs, **generate_kwargs)
        return self.tokenizer.decode(
            output[0, inputs["input_ids"].shape[1] :].detach().cpu(),
            skip_special_tokens=True,
        ).strip()

    def _generate_remote(
        self,
        prompt: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
    ) -> str:
        """POST *prompt* to ``HF_ENDPOINT_URL`` and return the generated text.

        Sampling flags mirror ``_generate_local``: ``temperature <= 0`` requests
        greedy decoding and omits temperature/top_p, which servers such as TGI
        reject at 0. ``top_p`` is sent only when it is strictly between 0 and 1.

        Raises:
            RuntimeError: on HTTP or network failure, a non-JSON body, or a body
                without recognisable generated text.
        """
        generation: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            # Without this, text-generation endpoints may echo the prompt, which
            # embeds the user's retrieved memories.
            "return_full_text": False,
        }
        if temperature > 0:
            generation["do_sample"] = True
            generation["temperature"] = temperature
            if 0 < top_p < 1:
                generation["top_p"] = top_p
        else:
            generation["do_sample"] = False
        payload = {"inputs": prompt, "parameters": generation}
        headers = {"Content-Type": "application/json"}
        if self.hf_token:
            headers["Authorization"] = f"Bearer {self.hf_token}"
        request = urllib.request.Request(
            self.endpoint_url,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        try:
            with urllib.request.urlopen(request, timeout=120) as response:  # noqa: S310
                raw = response.read().decode("utf-8")
        except urllib.error.HTTPError as exc:
            body = exc.read().decode("utf-8", errors="replace")
            raise RuntimeError(f"remote endpoint HTTP {exc.code}: {body[:300]}") from exc
        except urllib.error.URLError as exc:
            raise RuntimeError(f"remote endpoint unreachable: {exc.reason}") from exc
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"remote endpoint returned non-JSON body: {raw[:300]}") from exc
        return _extract_generated_text(parsed)
| # ---------------------------------------------------------------------- | |
| # Request, context, and formatting helpers | |
| # ---------------------------------------------------------------------- | |
| def _resolve_root(path: str) -> Path: | |
| if path: | |
| root = Path(path).resolve() | |
| return root.parent if root.is_file() else root | |
| return Path(__file__).resolve().parent | |
def _load_config(path: Path) -> Dict[str, Any]:
    """Return ``DEFAULT_CONFIG`` merged with the JSON object stored at *path*.

    A missing file yields a copy of the defaults.

    Raises:
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: if the file holds JSON that is not an object. Previously this
            either failed with a confusing ``dict.update`` error or silently
            merged a list of pairs.
    """
    config = dict(DEFAULT_CONFIG)
    if not path.exists():
        return config
    data = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        raise ValueError(f"{path} must contain a JSON object, got {type(data).__name__}.")
    config.update(data)
    return config
| def _normalize_request(data: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: | |
| if not isinstance(data, dict): | |
| raise ValueError("Request body must be a JSON object.") | |
| if "inputs" in data: | |
| inputs = data.get("inputs") or {} | |
| if isinstance(inputs, str): | |
| inputs = {"message": inputs} | |
| parameters = data.get("parameters") or {} | |
| else: | |
| inputs = data | |
| parameters = data.get("parameters") or {} | |
| if not isinstance(inputs, dict) or not isinstance(parameters, dict): | |
| raise ValueError("inputs and parameters must be JSON objects.") | |
| return inputs, parameters | |
| def _base_retrieval_mode(mode: str) -> str: | |
| return "tfidf" if str(mode).lower().startswith("tfidf") else "semantic" | |
| def _build_prompt( | |
| agent: SmritiAILite, | |
| memory: MemPalaceLite, | |
| user_id: str, | |
| topic_id: str, | |
| user_input: str, | |
| include_graph: bool, | |
| identity_enabled: bool, | |
| ) -> str: | |
| identity = agent.identity.get_identity_prompt() if identity_enabled else "" | |
| context = memory.get_context( | |
| query=user_input, | |
| session_id=user_id, | |
| topic_id=topic_id, | |
| include_graph=include_graph, | |
| ) | |
| parts = [part for part in [identity.strip(), context.strip(), user_input.strip()] if part] | |
| return "\n\n".join(parts) | |
| def _section_bullets(context: str, heading: str) -> List[str]: | |
| if heading not in context: | |
| return [] | |
| after = context.split(heading, 1)[1] | |
| chunks = re.split(r"\n\[[A-Z ]+\]", after, maxsplit=1) | |
| section = chunks[0] | |
| bullets = [] | |
| for line in section.splitlines(): | |
| cleaned = line.strip() | |
| if cleaned.startswith("*"): | |
| bullets.append(cleaned.lstrip("* ").strip()) | |
| return bullets | |
def _memory_only_answer(prompt: str) -> str:
    """Build a reply from remembered/graph fact bullets when no model is configured.

    Remembered facts come first, followed by graph facts not already present
    among them. At most five items are listed.
    """
    remembered = _section_bullets(prompt, "[REMEMBERED FACTS]")
    related = _section_bullets(prompt, "[RELATED GRAPH FACTS]")
    merged = list(remembered)
    merged.extend(item for item in related if item not in remembered)
    if not merged:
        return "Memory updated. No prior relevant context was found."
    return "I remember: " + "; ".join(merged[:5])
| def _is_recall_query(message: str) -> bool: | |
| lowered = message.lower() | |
| return any( | |
| phrase in lowered | |
| for phrase in [ | |
| "remember", | |
| "what do you know about me", | |
| "who am i", | |
| "where do i work", | |
| "what is my name", | |
| "what do i do", | |
| ] | |
| ) | |
def _stabilize_recall_answer(
    message: str,
    response: str,
    retrieved_memories: List[str],
    graph_facts: List[str],
) -> str:
    """Replace a recall answer that ignores the retrieved memories.

    For recall-style messages, if memories or graph facts exist but the
    response shares no 4+ character term with them, return an
    ``I remember: ...`` summary of up to five items instead. Otherwise the
    response is returned unchanged.
    """
    if _is_recall_query(message):
        merged = list(retrieved_memories)
        merged.extend(fact for fact in graph_facts if fact not in retrieved_memories)
        if merged and not _mentions_memory_terms(response, merged):
            return "I remember: " + "; ".join(merged[:5])
    return response
| def _mentions_memory_terms(response: str, memories: List[str]) -> bool: | |
| response_terms = set(re.findall(r"[a-z0-9']{4,}", response.lower())) | |
| memory_terms = set() | |
| for memory in memories: | |
| memory_terms.update(re.findall(r"[a-z0-9']{4,}", memory.lower())) | |
| return bool(response_terms & memory_terms) | |
| def _replace_last_assistant_history(memory: MemPalaceLite, response: str) -> None: | |
| if memory.history and memory.history[-1].category == "assistant_output": | |
| memory.history[-1].content = "Assistant: " + response[:200] | |
| def _extract_generated_text(parsed: Any) -> str: | |
| if isinstance(parsed, list) and parsed: | |
| return _extract_generated_text(parsed[0]) | |
| if isinstance(parsed, dict): | |
| for key in ["generated_text", "response", "text", "output"]: | |
| value = parsed.get(key) | |
| if isinstance(value, str): | |
| return value.strip() | |
| if "outputs" in parsed: | |
| return _extract_generated_text(parsed["outputs"]) | |
| if isinstance(parsed, str): | |
| return parsed.strip() | |
| raise RuntimeError("Remote endpoint did not return generated text.") | |
| def _json_root(memory_path: str) -> Path: | |
| path = Path(memory_path) | |
| if path.suffix.lower() in {".json", ".jsonl"}: | |
| return path.with_suffix("") | |
| return path | |
def _clean_model_id(value: str, *, allow_empty: bool = False) -> str:
    """Strip *value* (None/empty becomes "") and validate it for the current environment."""
    cleaned = value.strip() if value else ""
    return validate_model_id_for_environment(
        cleaned,
        context="Smriti AI Hugging Face handler",
        allow_empty=allow_empty,
    )
| def _bool_env(name: str, default: bool) -> bool: | |
| raw = os.getenv(name) | |
| if raw is None: | |
| return default | |
| return raw.strip().lower() in {"1", "true", "yes", "on"} | |
| def _int_env(name: str, default: int) -> int: | |
| try: | |
| return int(os.getenv(name, str(default))) | |
| except ValueError: | |
| return default | |
| def _error(message: str, start: float) -> Dict[str, Any]: | |
| return { | |
| "error": message, | |
| "latency_ms": round((time.perf_counter() - start) * 1000, 3), | |
| } | |