# AI-Agent/model.py
import hashlib
import os
import re
import threading
import time
from dataclasses import dataclass
from typing import Dict, Iterator, Optional

import torch
from transformers import TextIteratorStreamer, pipeline

DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
CACHE_TTL_SECONDS = int(os.getenv("RESPONSE_CACHE_TTL", "600"))


@dataclass
class CacheEntry:
    """A cached response together with its absolute expiry time."""

    value: str
    expires_at: float


class ModelManager:
    """Lazily loads a text-generation pipeline and serves cached, optionally streamed completions."""

    def __init__(self, model_id: str = DEFAULT_MODEL_ID) -> None:
        self.model_id = model_id
        self._generator = None
        self._tokenizer = None
        self._lock = threading.Lock()
        self._cache: Dict[str, CacheEntry] = {}

    def load(self) -> None:
        """Load the pipeline once; double-checked locking keeps concurrent callers safe."""
        if self._generator is not None:
            return
        with self._lock:
            if self._generator is not None:
                return
            try:
                self._generator = pipeline(
                    task="text-generation",
                    model=self.model_id,
                    tokenizer=self.model_id,
                    device=-1,  # CPU only
                    model_kwargs={
                        "torch_dtype": torch.float32,
                    },
                )
            except Exception:
                # Final fallback for constrained runtimes with strict model-loading behavior.
                self._generator = pipeline(
                    task="text-generation",
                    model=self.model_id,
                    tokenizer=self.model_id,
                    device=-1,
                )
            self._tokenizer = self._generator.tokenizer

    @staticmethod
    def dynamic_token_budget(message: str) -> int:
        """Pick a max_new_tokens budget from the message length and complexity hints."""
        words = len(message.split())
        lower = message.lower()
        complexity_hints = (
            "explain",
            "compare",
            "analyze",
            "step by step",
            "architecture",
            "strategy",
            "detailed",
        )
        if words <= 12 and not any(hint in lower for hint in complexity_hints):
            return 120
        if words <= 35:
            return 360
        return 720
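
    # Illustrative budget tiers (assuming simple whitespace word-splitting of the
    # message): "hello" -> 120 tokens; "explain X step by step" (short but
    # hinted) -> 360; any 13-35 word message -> 360; anything longer -> 720.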

    @staticmethod
    def _looks_incomplete(text: str, max_new_tokens: int) -> bool:
        """Heuristic: the output likely hit its token budget and does not end cleanly."""
        stripped = text.strip()
        if not stripped:
            return True
        # A word count near the token budget suggests generation was cut off mid-answer.
        likely_truncated = len(stripped.split()) >= int(max_new_tokens * 0.75)
        clean_endings = (".", "!", "?", "\"", "'", ")", "]", "}")
        has_clean_ending = stripped.endswith(clean_endings)
        return likely_truncated and not has_clean_ending

    @staticmethod
    def _build_prompt(message: str, memory_context: str, tool_context: str) -> str:
        system = (
            "You are a friendly, helpful general AI assistant. "
            "Use a warm, respectful tone and practical wording. "
            "Be concise when possible, but complete. "
            "Use prior context if relevant. If tools are provided, ground your answer in them. "
            "Output only the assistant answer. Do not write role labels like 'User:' or 'Assistant:'. "
            "Do not add unrelated sections such as 'Conclusion:' unless the user explicitly asked for them."
        )
        parts = [f"System: {system}"]
        if memory_context:
            parts.append(f"Conversation memory:\n{memory_context}")
        if tool_context:
            parts.append(f"Tool results:\n{tool_context}")
        parts.append(f"User: {message}")
        parts.append("Assistant:")
        return "\n\n".join(parts)

    def _cache_key(self, prompt: str, max_new_tokens: int) -> str:
        material = f"{self.model_id}|{max_new_tokens}|{prompt}".encode("utf-8")
        return hashlib.sha256(material).hexdigest()

    def _get_cached(self, key: str) -> Optional[str]:
        entry = self._cache.get(key)
        if not entry:
            return None
        if time.time() > entry.expires_at:
            self._cache.pop(key, None)
            return None
        return entry.value

    def _set_cached(self, key: str, value: str) -> None:
        # Expired entries are evicted lazily on read; the cache is otherwise unbounded.
        self._cache[key] = CacheEntry(value=value, expires_at=time.time() + CACHE_TTL_SECONDS)

    def _generation_kwargs(self, max_new_tokens: int) -> Dict[str, object]:
        return {
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "repetition_penalty": 1.08,
            "eos_token_id": self._tokenizer.eos_token_id,
            "pad_token_id": self._tokenizer.eos_token_id,
        }

    @staticmethod
    def _clean_response(text: str) -> str:
        cleaned = text.strip()
        if not cleaned:
            return cleaned
        # Keep only the first assistant turn if the model starts fabricating dialogue.
        split_markers = ["\nUser:", "\nAssistant:", "\nSystem:"]
        for marker in split_markers:
            pos = cleaned.find(marker)
            if pos != -1:
                cleaned = cleaned[:pos].strip()
        # Trim generic wrap-up sections that tiny models often hallucinate.
        for marker in ["\nConclusion:", "\nFinal answer:"]:
            pos = cleaned.find(marker)
            if pos != -1:
                cleaned = cleaned[:pos].strip()
        cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
        # Avoid abrupt trailing fragments when the model ends mid-word/phrase.
        if cleaned and cleaned[-1] not in ".!?\"')]}":
            cleaned = cleaned.rstrip(" ,;:-") + "."
        return cleaned

    def clean_response(self, text: str) -> str:
        return self._clean_response(text)

    def generate(self, message: str, memory_context: str = "", tool_context: str = "") -> str:
        self.load()
        max_new_tokens = self.dynamic_token_budget(message)
        prompt = self._build_prompt(message, memory_context, tool_context)
        key = self._cache_key(prompt, max_new_tokens)
        cached = self._get_cached(key)
        if cached:
            return cached
        output = self._generator(
            prompt,
            return_full_text=False,
            **self._generation_kwargs(max_new_tokens),
        )[0]["generated_text"]
        # Continue generation when the output appears cut off.
        attempts = 0
        combined = output.strip()
        while attempts < 2 and self._looks_incomplete(combined, max_new_tokens):
            continuation_prompt = (
                f"{prompt}\n{combined}\nContinue the same answer from where it stopped, "
                "without repeating earlier sentences:\n"
            )
            extra = self._generator(
                continuation_prompt,
                max_new_tokens=160,
                do_sample=True,
                temperature=0.65,
                top_p=0.9,
                repetition_penalty=1.08,
                eos_token_id=self._tokenizer.eos_token_id,
                pad_token_id=self._tokenizer.eos_token_id,
                return_full_text=False,
            )[0]["generated_text"].strip()
            if not extra:
                break
            combined = f"{combined} {extra}".strip()
            attempts += 1
        result = self._clean_response(combined)
        self._set_cached(key, result)
        return result

    def stream_generate(self, message: str, memory_context: str = "", tool_context: str = "") -> Iterator[str]:
        self.load()
        max_new_tokens = self.dynamic_token_budget(message)
        prompt = self._build_prompt(message, memory_context, tool_context)
        key = self._cache_key(prompt, max_new_tokens)
        cached = self._get_cached(key)
        if cached:
            yield cached
            return
        model = self._generator.model
        tokenizer = self._tokenizer
        inputs = tokenizer(prompt, return_tensors="pt")
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            **self._generation_kwargs(max_new_tokens),
        }
        worker = threading.Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
        worker.start()
        markers = ["\nUser:", "\nAssistant:", "\nSystem:", "User:", "Assistant:", "System:"]
        buffer = ""
        yielded_len = 0
        stop_idx = -1
        for piece in streamer:
            if not piece:
                continue
            buffer += piece
            # Find the earliest marker in the accumulated text (handles markers split across chunks).
            marker_positions = [pos for pos in (buffer.find(m) for m in markers) if pos != -1]
            if marker_positions:
                stop_idx = min(marker_positions)
            # Hold back a short tail so markers crossing chunk boundaries are still detected safely.
            safe_upto = len(buffer) - 20 if stop_idx == -1 else stop_idx
            if safe_upto > yielded_len:
                out = buffer[yielded_len:safe_upto]
                if out:
                    yield out
                yielded_len = safe_upto
            if stop_idx != -1:
                break
        # The daemon worker may still be generating if we stopped early on a marker; don't block on it.
        worker.join(timeout=0.1)
        if stop_idx == -1 and yielded_len < len(buffer):
            out = buffer[yielded_len:]
            if out:
                yield out
        truncated_final = buffer[:stop_idx] if stop_idx != -1 else buffer
        final_text = self._clean_response(truncated_final)
        if final_text:
            self._set_cached(key, final_text)


_model_manager: Optional[ModelManager] = None
_model_manager_lock = threading.Lock()


def get_model_manager() -> ModelManager:
    global _model_manager
    # Guard construction so concurrent first calls cannot build (and load) two managers.
    if _model_manager is None:
        with _model_manager_lock:
            if _model_manager is None:
                _model_manager = ModelManager()
    return _model_manager
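

# Minimal local smoke test -- an illustrative sketch, not part of the service
# API. It assumes the default TinyLlama checkpoint can be downloaded and that
# CPU-only inference (device=-1 above) is acceptable.
if __name__ == "__main__":
    manager = get_model_manager()
    print(manager.generate("Give me two tips for writing clear commit messages."))
    # Streamed variant: chunks are yielded as they are generated.
    for chunk in manager.stream_generate("Explain what a TTL cache is, step by step."):
        print(chunk, end="", flush=True)
    print()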