"""Transformers + PyTorch text generation service for KORA.""" from __future__ import annotations import asyncio import logging from threading import Thread from uuid import uuid4 import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer from app.utils.config import Settings logger = logging.getLogger(__name__) MIN_TEMPERATURE = 1e-5 class ModelService: """Manages Transformers model lifecycle and CPU-safe generation.""" def __init__(self, settings: Settings) -> None: self.settings = settings self._model = None self._tokenizer = None self._startup_lock = asyncio.Lock() self._generation_lock = asyncio.Lock() async def startup(self) -> None: """Initialize model engine once per process.""" if self._model is not None and self._tokenizer is not None: return async with self._startup_lock: if self._model is not None and self._tokenizer is not None: return if self.settings.torch_num_threads > 0: torch.set_num_threads(self.settings.torch_num_threads) if self.settings.torch_num_interop_threads > 0: torch.set_num_interop_threads(self.settings.torch_num_interop_threads) logger.info("Loading model via Transformers on CPU: %s", self.settings.model_name) self._tokenizer = AutoTokenizer.from_pretrained( "microsoft/Phi-3-mini-4k-instruct", trust_remote_code=self.settings.trust_remote_code, use_fast=True, ) if self._tokenizer.pad_token is None and self._tokenizer.eos_token is not None: self._tokenizer.pad_token = self._tokenizer.eos_token from peft import PeftModel logger.info("Loading base model...") base_model = AutoModelForCausalLM.from_pretrained( "microsoft/Phi-3-mini-4k-instruct", trust_remote_code=self.settings.trust_remote_code, torch_dtype=torch.float32, device_map="cpu" ) logger.info("Applying PEFT adapter...") self._model = PeftModel.from_pretrained(base_model, self.settings.model_name) self._model.eval() logger.info("CPU model and tokenizer initialized") async def shutdown(self) -> None: """Graceful shutdown hook.""" self._model = None self._tokenizer = None def _build_prompt(self, messages: list[dict[str, str]]) -> str: """Render OpenAI-style messages into a model prompt.""" if self._tokenizer is None: raise RuntimeError("Tokenizer is not initialized") try: return self._tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) except (AttributeError, TypeError, ValueError): logger.warning("Chat template not supported, using fallback format", exc_info=True) # Fallback for tokenizer templates that may not support message format. lines = [f"{m['role'].upper()}: {m['content']}" for m in messages] lines.append("ASSISTANT:") return "\n".join(lines) def _build_generation_kwargs( self, *, temperature: float, top_p: float, max_tokens: int, ) -> dict: if self._tokenizer is None: raise RuntimeError("Tokenizer is not initialized") clamped_temperature = max(0.0, float(temperature)) clamped_top_p = min(max(float(top_p), 0.0), 1.0) do_sample = clamped_temperature >= MIN_TEMPERATURE kwargs = { "max_new_tokens": max(1, int(max_tokens)), "do_sample": do_sample, "pad_token_id": self._tokenizer.pad_token_id, "eos_token_id": self._tokenizer.eos_token_id, } if do_sample: kwargs["temperature"] = max(clamped_temperature, MIN_TEMPERATURE) kwargs["top_p"] = clamped_top_p if clamped_top_p > 0.0 else 1.0 return kwargs def _tokenize_prompt(self, prompt: str) -> dict: if self._tokenizer is None: raise RuntimeError("Tokenizer is not initialized") return self._tokenizer( prompt, return_tensors="pt", truncation=True, max_length=self.settings.max_input_tokens, ) async def _ensure_ready(self) -> None: if self._model is None or self._tokenizer is None: await self.startup() if self._model is None or self._tokenizer is None: raise RuntimeError("Model service failed to initialize") async def stream_text( self, messages: list[dict[str, str]], *, temperature: float, top_p: float, max_tokens: int, ): """Yield incremental token deltas for SSE streaming.""" await self._ensure_ready() request_id = f"chatcmpl-{uuid4().hex}" prompt = self._build_prompt(messages) inputs = self._tokenize_prompt(prompt) generation_kwargs = self._build_generation_kwargs( temperature=temperature, top_p=top_p, max_tokens=max_tokens, ) streamer = TextIteratorStreamer( self._tokenizer, skip_prompt=True, skip_special_tokens=True, ) generation_error: Exception | None = None def run_generation() -> None: nonlocal generation_error try: with torch.inference_mode(): self._model.generate( **inputs, streamer=streamer, **generation_kwargs, ) except (RuntimeError, ValueError, TypeError) as exc: # pragma: no cover - runtime guard generation_error = exc logger.exception("Streaming generation failed") async with self._generation_lock: worker = Thread(target=run_generation) worker.start() iterator = iter(streamer) while True: token = await asyncio.to_thread(next, iterator, None) if token is None: break if generation_error is not None: break yield request_id, token await asyncio.to_thread(worker.join) if generation_error is not None: raise RuntimeError("Streaming generation failed") from generation_error async def complete_text( self, messages: list[dict[str, str]], *, temperature: float, top_p: float, max_tokens: int, ) -> tuple[str, str]: """Generate the final full completion in non-stream mode.""" await self._ensure_ready() request_id = f"chatcmpl-{uuid4().hex}" prompt = self._build_prompt(messages) inputs = self._tokenize_prompt(prompt) generation_kwargs = self._build_generation_kwargs( temperature=temperature, top_p=top_p, max_tokens=max_tokens, ) async with self._generation_lock: output_ids = await asyncio.to_thread( self._generate_sync, inputs, generation_kwargs, ) input_ids = inputs.get("input_ids") if input_ids is None: raise RuntimeError("Tokenization failed to produce input_ids") prompt_token_count = int(input_ids.shape[-1]) generated_ids = output_ids[0][prompt_token_count:] final_text = self._tokenizer.decode(generated_ids, skip_special_tokens=True).strip() return request_id, final_text def _generate_sync(self, inputs: dict, generation_kwargs: dict): if self._model is None: raise RuntimeError("Model is not initialized") with torch.inference_mode(): return self._model.generate(**inputs, **generation_kwargs)