import hashlib
import os
import re
import threading
import time
from dataclasses import dataclass
from typing import Dict, Iterator, Optional

import torch
from transformers import TextIteratorStreamer, pipeline

DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
CACHE_TTL_SECONDS = int(os.getenv("RESPONSE_CACHE_TTL", "600"))


@dataclass
class CacheEntry:
    value: str
    expires_at: float


class ModelManager:
    def __init__(self, model_id: str = DEFAULT_MODEL_ID) -> None:
        self.model_id = model_id
        self._generator = None
        self._tokenizer = None
        self._lock = threading.Lock()
        self._cache: Dict[str, CacheEntry] = {}

    def load(self) -> None:
        # Double-checked locking: the lock-free fast path skips the lock once loaded.
        if self._generator is not None:
            return
        with self._lock:
            if self._generator is not None:
                return
            try:
                generator = pipeline(
                    task="text-generation",
                    model=self.model_id,
                    tokenizer=self.model_id,
                    device=-1,
                    model_kwargs={
                        "torch_dtype": torch.float32,
                    },
                )
            except Exception:
                # Final fallback for constrained runtimes with strict model loading behavior.
                generator = pipeline(
                    task="text-generation",
                    model=self.model_id,
                    tokenizer=self.model_id,
                    device=-1,
                )
            # Publish the tokenizer before the generator: other threads treat a
            # non-None _generator as "fully loaded" on the lock-free fast path.
            self._tokenizer = generator.tokenizer
            self._generator = generator

    @staticmethod
    def dynamic_token_budget(message: str) -> int:
        # Scale the generation budget with prompt length and complexity cues.
        words = len(message.split())
        lower = message.lower()
        complexity_hints = (
            "explain",
            "compare",
            "analyze",
            "step by step",
            "architecture",
            "strategy",
            "detailed",
        )
        if words <= 12 and not any(hint in lower for hint in complexity_hints):
            return 120
        if words <= 35:
            return 360
        return 720

    @staticmethod
    def _looks_incomplete(text: str, max_new_tokens: int) -> bool:
        # Treat a reply as truncated when it used most of its token budget yet
        # lacks a sentence-final character.
        stripped = text.strip()
        if not stripped:
            return True
        likely_truncated = len(stripped.split()) >= int(max_new_tokens * 0.75)
        clean_endings = (".", "!", "?", "\"", "'", ")", "]", "}")
        has_clean_ending = stripped.endswith(clean_endings)
        return likely_truncated and not has_clean_ending

    @staticmethod
    def _build_prompt(message: str, memory_context: str, tool_context: str) -> str:
        system = (
            "You are a friendly, helpful general AI assistant. "
            "Use a warm, respectful tone and practical wording. "
            "Be concise when possible, but complete. "
            "Use prior context if relevant. If tools are provided, ground your answer in them. "
            "Output only the assistant answer. Do not write role labels like 'User:' or 'Assistant:'. "
            "Do not add unrelated sections such as 'Conclusion:' unless the user explicitly asked for them."
        )
        parts = [f"System: {system}"]
        if memory_context:
            parts.append(f"Conversation memory:\n{memory_context}")
        if tool_context:
            parts.append(f"Tool results:\n{tool_context}")
        parts.append(f"User: {message}")
        parts.append("Assistant:")
        return "\n\n".join(parts)

    def _cache_key(self, prompt: str, max_new_tokens: int) -> str:
        material = f"{self.model_id}|{max_new_tokens}|{prompt}".encode("utf-8")
        return hashlib.sha256(material).hexdigest()

    def _get_cached(self, key: str) -> Optional[str]:
        entry = self._cache.get(key)
        if not entry:
            return None
        if time.time() > entry.expires_at:
            self._cache.pop(key, None)
            return None
        return entry.value

    def _set_cached(self, key: str, value: str) -> None:
        self._cache[key] = CacheEntry(value=value, expires_at=time.time() + CACHE_TTL_SECONDS)

    def _generation_kwargs(self, max_new_tokens: int) -> Dict[str, object]:
        return {
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "repetition_penalty": 1.08,
            "eos_token_id": self._tokenizer.eos_token_id,
            "pad_token_id": self._tokenizer.eos_token_id,
        }

    @staticmethod
    def _clean_response(text: str) -> str:
        cleaned = text.strip()
        if not cleaned:
            return cleaned
        # Keep only the first assistant turn if the model starts fabricating dialogue.
        split_markers = ["\nUser:", "\nAssistant:", "\nSystem:"]
        for marker in split_markers:
            pos = cleaned.find(marker)
            if pos != -1:
                cleaned = cleaned[:pos].strip()
        # Trim generic wrap-up sections that tiny models often hallucinate.
        for marker in ["\nConclusion:", "\nFinal answer:"]:
            pos = cleaned.find(marker)
            if pos != -1:
                cleaned = cleaned[:pos].strip()
        cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
        # Avoid abrupt trailing fragments when the model ends mid-word/phrase.
        if cleaned and cleaned[-1] not in ".!?\"')]}":
            cleaned = cleaned.rstrip(" ,;:-") + "."
        return cleaned

    def clean_response(self, text: str) -> str:
        return self._clean_response(text)

    def generate(self, message: str, memory_context: str = "", tool_context: str = "") -> str:
        self.load()
        max_new_tokens = self.dynamic_token_budget(message)
        prompt = self._build_prompt(message, memory_context, tool_context)
        key = self._cache_key(prompt, max_new_tokens)
        cached = self._get_cached(key)
        if cached:
            return cached
        output = self._generator(
            prompt,
            return_full_text=False,
            **self._generation_kwargs(max_new_tokens),
        )[0]["generated_text"]
        # Continue generation when output appears cut off.
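        # Stitched continuations can still repeat phrases; the prompt below asks
        # the model not to, but small models may ignore that instruction.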
        attempts = 0
        combined = output.strip()
        while attempts < 2 and self._looks_incomplete(combined, max_new_tokens):
            continuation_prompt = (
                f"{prompt}\n{combined}\nContinue the same answer from where it stopped, "
                "without repeating earlier sentences:\n"
            )
            extra = self._generator(
                continuation_prompt,
                max_new_tokens=160,
                do_sample=True,
                temperature=0.65,
                top_p=0.9,
                repetition_penalty=1.08,
                eos_token_id=self._tokenizer.eos_token_id,
                pad_token_id=self._tokenizer.eos_token_id,
                return_full_text=False,
            )[0]["generated_text"].strip()
            if not extra:
                break
            combined = f"{combined} {extra}".strip()
            attempts += 1
        result = self._clean_response(combined)
        self._set_cached(key, result)
        return result

    def stream_generate(
        self, message: str, memory_context: str = "", tool_context: str = ""
    ) -> Iterator[str]:
        self.load()
        max_new_tokens = self.dynamic_token_budget(message)
        prompt = self._build_prompt(message, memory_context, tool_context)
        key = self._cache_key(prompt, max_new_tokens)
        cached = self._get_cached(key)
        if cached:
            yield cached
            return
        model = self._generator.model
        tokenizer = self._tokenizer
        inputs = tokenizer(prompt, return_tensors="pt")
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            **self._generation_kwargs(max_new_tokens),
        }
        # Run generation on a daemon thread so this generator can consume the
        # streamer incrementally on the calling thread.
        worker = threading.Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
        worker.start()
        markers = ["\nUser:", "\nAssistant:", "\nSystem:", "User:", "Assistant:", "System:"]
        buffer = ""
        yielded_len = 0
        stop_idx = -1
        for piece in streamer:
            if not piece:
                continue
            buffer += piece
            # Find earliest marker in accumulated text (handles marker split across chunks).
            marker_positions = [buffer.find(m) for m in markers if buffer.find(m) != -1]
            if marker_positions:
                stop_idx = min(marker_positions)
            # Hold a short tail so markers crossing boundaries are still detected safely.
            safe_upto = len(buffer) - 20 if stop_idx == -1 else stop_idx
            if safe_upto > yielded_len:
                out = buffer[yielded_len:safe_upto]
                if out:
                    yield out
                yielded_len = safe_upto
            if stop_idx != -1:
                break
        worker.join(timeout=0.1)
        # Flush the held-back tail when the stream ended without hitting a marker.
        if stop_idx == -1 and yielded_len < len(buffer):
            out = buffer[yielded_len:]
            if out:
                yield out
        truncated_final = buffer[:stop_idx] if stop_idx != -1 else buffer
        final_text = self._clean_response(truncated_final)
        if final_text:
            self._set_cached(key, final_text)


_model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager
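

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module's public contract): a minimal smoke
# test, assuming this file is run directly. The prompts are illustrative.
if __name__ == "__main__":
    manager = get_model_manager()

    # Blocking call: returns one cleaned answer, cached for CACHE_TTL_SECONDS,
    # so repeating the same prompt within the TTL is served from memory.
    print(manager.generate("Explain the difference between a process and a thread."))

    # Streaming call: chunks arrive as the streamer releases them; a short
    # tail is held back so dialogue markers are never emitted mid-stream.
    for chunk in manager.stream_generate("What is a hash map?"):
        print(chunk, end="", flush=True)
    print()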