import hashlib
import os
import re
import threading
import time
from dataclasses import dataclass
from typing import Dict, Iterator, Optional

import torch
from transformers import TextIteratorStreamer, pipeline


DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
CACHE_TTL_SECONDS = int(os.getenv("RESPONSE_CACHE_TTL", "600"))


@dataclass
class CacheEntry:
    """A cached response value and the time at which it expires."""

    value: str
    expires_at: float


class ModelManager:
    """Lazily loads a Hugging Face text-generation pipeline and serves cached, cleaned answers."""

    def __init__(self, model_id: str = DEFAULT_MODEL_ID) -> None:
        self.model_id = model_id
        self._generator = None
        self._tokenizer = None
        self._lock = threading.Lock()
        self._cache: Dict[str, CacheEntry] = {}

    def load(self) -> None:
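        """Load the generation pipeline once, guarded by a lock so concurrent callers share one model."""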
        if self._generator is not None:
            return

        with self._lock:
            if self._generator is not None:
                return

            try:
                generator = pipeline(
                    task="text-generation",
                    model=self.model_id,
                    tokenizer=self.model_id,
                    device=-1,
                    model_kwargs={
                        "torch_dtype": torch.float32,
                    },
                )
            except Exception:
                # Fall back to the library's default dtype if the float32 override is rejected.
                generator = pipeline(
                    task="text-generation",
                    model=self.model_id,
                    tokenizer=self.model_id,
                    device=-1,
                )
            # Publish the tokenizer before the generator so the unlocked fast path
            # never sees a generator without its tokenizer.
            self._tokenizer = generator.tokenizer
            self._generator = generator

    @staticmethod
    def dynamic_token_budget(message: str) -> int:
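        """Pick a max_new_tokens budget from message length and simple complexity keywords."""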
        words = len(message.split())
        lower = message.lower()
        complexity_hints = (
            "explain",
            "compare",
            "analyze",
            "step by step",
            "architecture",
            "strategy",
            "detailed",
        )

        if words <= 12 and not any(hint in lower for hint in complexity_hints):
            return 120
        if words <= 35:
            return 360
        return 720

    @staticmethod
    def _looks_incomplete(text: str, max_new_tokens: int) -> bool:
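        """Heuristic: the text probably hit its token budget and stopped without a clean ending."""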
        stripped = text.strip()
        if not stripped:
            return True

        likely_truncated = len(stripped.split()) >= int(max_new_tokens * 0.75)
        clean_endings = (".", "!", "?", "\"", "'", ")", "]", "}")
        has_clean_ending = stripped.endswith(clean_endings)
        return likely_truncated and not has_clean_ending

    @staticmethod
    def _build_prompt(message: str, memory_context: str, tool_context: str) -> str:
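        """Assemble a plain-text prompt from the system instructions, optional memory and tool context, and the user message."""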
        system = (
            "You are a friendly, helpful general AI assistant. "
            "Use a warm, respectful tone and practical wording. "
            "Be concise when possible, but complete. "
            "Use prior context if relevant. If tools are provided, ground your answer in them. "
            "Output only the assistant answer. Do not write role labels like 'User:' or 'Assistant:'. "
            "Do not add unrelated sections such as 'Conclusion:' unless the user explicitly asked for them."
        )

        parts = [f"System: {system}"]
        if memory_context:
            parts.append(f"Conversation memory:\n{memory_context}")
        if tool_context:
            parts.append(f"Tool results:\n{tool_context}")
        parts.append(f"User: {message}")
        parts.append("Assistant:")
        return "\n\n".join(parts)

    def _cache_key(self, prompt: str, max_new_tokens: int) -> str:
        material = f"{self.model_id}|{max_new_tokens}|{prompt}".encode("utf-8")
        return hashlib.sha256(material).hexdigest()

    def _get_cached(self, key: str) -> Optional[str]:
        entry = self._cache.get(key)
        if not entry:
            return None
        if time.time() > entry.expires_at:
            self._cache.pop(key, None)
            return None
        return entry.value

    def _set_cached(self, key: str, value: str) -> None:
        self._cache[key] = CacheEntry(value=value, expires_at=time.time() + CACHE_TTL_SECONDS)

    def _generation_kwargs(self, max_new_tokens: int) -> Dict[str, object]:
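        """Sampling settings shared by generate() and stream_generate(); EOS is reused as the pad token."""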
        return {
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "repetition_penalty": 1.08,
            "eos_token_id": self._tokenizer.eos_token_id,
            "pad_token_id": self._tokenizer.eos_token_id,
        }

    @staticmethod
    def _clean_response(text: str) -> str:
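        """Strip leaked role labels and boilerplate sections, then tidy spacing and punctuation."""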
        cleaned = text.strip()
        if not cleaned:
            return cleaned

        # Cut the response at the first leaked role label.
        split_markers = ["\nUser:", "\nAssistant:", "\nSystem:"]
        for marker in split_markers:
            pos = cleaned.find(marker)
            if pos != -1:
                cleaned = cleaned[:pos].strip()

        # Drop trailing boilerplate sections the prompt asks the model to avoid.
        for marker in ["\nConclusion:", "\nFinal answer:"]:
            pos = cleaned.find(marker)
            if pos != -1:
                cleaned = cleaned[:pos].strip()

        cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)

        # End on terminal punctuation so truncated answers read as finished sentences.
        if cleaned and cleaned[-1] not in ".!?\"')]}":
            cleaned = cleaned.rstrip(" ,;:-") + "."

        return cleaned

    def clean_response(self, text: str) -> str:
        return self._clean_response(text)

    def generate(self, message: str, memory_context: str = "", tool_context: str = "") -> str:
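        """Return a full answer: serve from cache when possible, otherwise generate, extend truncated output up to twice, clean, and cache."""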
        self.load()
        max_new_tokens = self.dynamic_token_budget(message)
        prompt = self._build_prompt(message, memory_context, tool_context)

        key = self._cache_key(prompt, max_new_tokens)
        cached = self._get_cached(key)
        if cached:
            return cached

        output = self._generator(
            prompt,
            return_full_text=False,
            **self._generation_kwargs(max_new_tokens),
        )[0]["generated_text"]

        # If the answer looks cut off, ask the model to continue it (at most twice).
        attempts = 0
        combined = output.strip()
        while attempts < 2 and self._looks_incomplete(combined, max_new_tokens):
            continuation_prompt = (
                f"{prompt}\n{combined}\nContinue the same answer from where it stopped, "
                "without repeating earlier sentences:\n"
            )
            extra = self._generator(
                continuation_prompt,
                max_new_tokens=160,
                do_sample=True,
                temperature=0.65,
                top_p=0.9,
                repetition_penalty=1.08,
                eos_token_id=self._tokenizer.eos_token_id,
                pad_token_id=self._tokenizer.eos_token_id,
                return_full_text=False,
            )[0]["generated_text"].strip()

            if not extra:
                break

            combined = f"{combined} {extra}".strip()
            attempts += 1

        result = self._clean_response(combined)
        self._set_cached(key, result)
        return result

    def stream_generate(self, message: str, memory_context: str = "", tool_context: str = "") -> Iterator[str]:
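        """Yield answer chunks as they are generated, stopping early if a role label leaks into the output."""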
        self.load()
        max_new_tokens = self.dynamic_token_budget(message)
        prompt = self._build_prompt(message, memory_context, tool_context)

        key = self._cache_key(prompt, max_new_tokens)
        cached = self._get_cached(key)
        if cached:
            yield cached
            return

        model = self._generator.model
        tokenizer = self._tokenizer

        inputs = tokenizer(prompt, return_tensors="pt")
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            **self._generation_kwargs(max_new_tokens),
        }

        # Run generation on a worker thread; the streamer feeds this generator as tokens arrive.
        worker = threading.Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
        worker.start()

        markers = ["\nUser:", "\nAssistant:", "\nSystem:", "User:", "Assistant:", "System:"]
        buffer = ""
        yielded_len = 0
        stop_idx = -1

        for piece in streamer:
            if not piece:
                continue

            buffer += piece

            # Stop streaming at the earliest leaked role label, if any.
            marker_positions = [pos for pos in (buffer.find(m) for m in markers) if pos != -1]
            if marker_positions:
                stop_idx = min(marker_positions)

            # Hold back the last few characters so a marker split across chunks is never emitted.
            safe_upto = len(buffer) - 20 if stop_idx == -1 else stop_idx
            if safe_upto > yielded_len:
                out = buffer[yielded_len:safe_upto]
                if out:
                    yield out
                yielded_len = safe_upto

            if stop_idx != -1:
                break

        # Give the daemon worker a moment to finish; it will not block shutdown.
        worker.join(timeout=0.1)

        # Flush whatever was held back once the stream ends normally.
        if stop_idx == -1 and yielded_len < len(buffer):
            out = buffer[yielded_len:]
            if out:
                yield out

        truncated_final = buffer[:stop_idx] if stop_idx != -1 else buffer
        final_text = self._clean_response(truncated_final)
        if final_text:
            self._set_cached(key, final_text)


_model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
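    """Return the process-wide ModelManager, creating it on first use."""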
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager


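# A minimal usage sketch, not part of the original module: it assumes the model named by
# DEFAULT_MODEL_ID can be downloaded and that torch and transformers are installed locally.
if __name__ == "__main__":
    manager = get_model_manager()

    # One-shot generation with caching and truncation retries.
    print(manager.generate("Explain what a response cache is, step by step."))

    # Streaming generation: chunks arrive as the model produces them.
    for chunk in manager.stream_generate("Give me a short greeting."):
        print(chunk, end="", flush=True)
    print()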