# Kora-API / app/services/model_service.py
# ProfessorCEO — "Deploy API" (commit e251d62)
"""Transformers + PyTorch text generation service for KORA."""
from __future__ import annotations

import asyncio
import logging
from threading import Event, Thread
from uuid import uuid4

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

from app.utils.config import Settings
logger = logging.getLogger(__name__)

# Temperatures below this are treated as greedy decoding: sampling is disabled
# instead of dividing logits by a (near-)zero temperature.
MIN_TEMPERATURE = 1e-5

# Base checkpoint that the PEFT adapter (``settings.model_name``) is applied to.
DEFAULT_BASE_MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"


class ModelService:
    """Manages Transformers model lifecycle and CPU-safe generation.

    Loads a base causal LM plus a PEFT adapter on CPU, and serialises every
    ``generate`` call through one asyncio lock, so at most one generation runs
    at a time per service instance.
    """

    def __init__(self, settings: Settings, *, base_model_name: str | None = None) -> None:
        """Create an unloaded service; the model is loaded by ``startup``.

        Args:
            settings: Application settings. Reads ``model_name`` (PEFT adapter
                id/path), ``trust_remote_code``, ``torch_num_threads``,
                ``torch_num_interop_threads`` and ``max_input_tokens``.
            base_model_name: Id/path of the base model the adapter is applied
                to. Defaults to ``DEFAULT_BASE_MODEL_NAME``.
        """
        self.settings = settings
        self.base_model_name = base_model_name or DEFAULT_BASE_MODEL_NAME
        self._model = None
        self._tokenizer = None
        self._startup_lock = asyncio.Lock()
        self._generation_lock = asyncio.Lock()

    def _is_ready(self) -> bool:
        """Return True once both the tokenizer and the model are published."""
        return self._model is not None and self._tokenizer is not None

    async def startup(self) -> None:
        """Initialize model engine once per process.

        Idempotent: concurrent callers are serialised by ``_startup_lock`` and
        re-check readiness after acquiring it. The blocking load runs in a
        worker thread so the event loop keeps serving other work meanwhile.
        """
        if self._is_ready():
            return
        async with self._startup_lock:
            if self._is_ready():
                return
            self._configure_torch_threads()
            tokenizer, model = await asyncio.to_thread(self._load_sync)
            # Publish both together (no await in between) so a failed load never
            # leaves a tokenizer without its model.
            self._tokenizer = tokenizer
            self._model = model
            logger.info("CPU model and tokenizer initialized")

    def _configure_torch_threads(self) -> None:
        """Apply configured torch intra-/inter-op thread counts; values <= 0 are ignored."""
        if self.settings.torch_num_threads > 0:
            torch.set_num_threads(self.settings.torch_num_threads)
        if self.settings.torch_num_interop_threads > 0:
            try:
                torch.set_num_interop_threads(self.settings.torch_num_interop_threads)
            except RuntimeError:
                # PyTorch accepts this only once per process; a startup retried
                # after a failed load must not fail here.
                logger.warning(
                    "Could not set torch inter-op threads; keeping current value",
                    exc_info=True,
                )

    def _load_sync(self):
        """Load tokenizer, base model and PEFT adapter (blocking).

        Returns:
            ``(tokenizer, model)`` with the model already in eval mode.
        """
        # Imported lazily, as in the original loader; only needed at load time.
        from peft import PeftModel

        base_name = self.base_model_name
        adapter_name = self.settings.model_name
        logger.info(
            "Loading base model %s with PEFT adapter %s via Transformers on CPU",
            base_name,
            adapter_name,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            base_name,
            trust_remote_code=self.settings.trust_remote_code,
            use_fast=True,
        )
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        # Prompts end with the generation prompt / "ASSISTANT:" marker, so an
        # over-long prompt must lose its oldest context, not its tail.
        tokenizer.truncation_side = "left"

        logger.info("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_name,
            trust_remote_code=self.settings.trust_remote_code,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
        logger.info("Applying PEFT adapter...")
        model = PeftModel.from_pretrained(base_model, adapter_name)
        model.eval()
        return tokenizer, model

    async def shutdown(self) -> None:
        """Graceful shutdown hook: drop the model and tokenizer references.

        Does not wait for in-flight generation. A later request that needs the
        model reloads it through ``_ensure_ready`` -> ``startup``.
        """
        self._model = None
        self._tokenizer = None

    def _build_prompt(self, messages: list[dict[str, str]]) -> str:
        """Render OpenAI-style messages into a model prompt.

        Uses the tokenizer's chat template with the generation prompt appended.
        If the template cannot be applied, falls back to ``ROLE: content`` lines
        followed by ``ASSISTANT:``.

        Raises:
            RuntimeError: if the tokenizer is not initialized.
        """
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer is not initialized")
        try:
            return self._tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except (AttributeError, TypeError, ValueError):
            logger.warning("Chat template not supported, using fallback format", exc_info=True)
            lines = [f"{m['role'].upper()}: {m['content']}" for m in messages]
            lines.append("ASSISTANT:")
            return "\n".join(lines)

    def _build_generation_kwargs(
        self,
        *,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ) -> dict:
        """Translate request sampling parameters into ``generate`` kwargs.

        A temperature below ``MIN_TEMPERATURE`` selects greedy decoding, and no
        temperature/top_p is sent in that case. ``top_p`` is clamped to [0, 1],
        and 0 is mapped to 1.0 (no nucleus filtering). At least one new token
        is always requested.

        Raises:
            RuntimeError: if the tokenizer is not initialized.
        """
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer is not initialized")
        clamped_temperature = max(0.0, float(temperature))
        clamped_top_p = min(max(float(top_p), 0.0), 1.0)
        do_sample = clamped_temperature >= MIN_TEMPERATURE
        kwargs = {
            "max_new_tokens": max(1, int(max_tokens)),
            "do_sample": do_sample,
            "pad_token_id": self._tokenizer.pad_token_id,
            "eos_token_id": self._tokenizer.eos_token_id,
        }
        if do_sample:
            kwargs["temperature"] = max(clamped_temperature, MIN_TEMPERATURE)
            kwargs["top_p"] = clamped_top_p if clamped_top_p > 0.0 else 1.0
        return kwargs

    def _tokenize_prompt(self, prompt: str) -> dict:
        """Tokenize *prompt* into PyTorch tensors, truncated to ``max_input_tokens``.

        Raises:
            RuntimeError: if the tokenizer is not initialized.
        """
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer is not initialized")
        # NOTE(review): the chat-template output may already contain BOS; the
        # default add_special_tokens=True could duplicate it — confirm for this
        # tokenizer.
        return self._tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.settings.max_input_tokens,
        )

    async def _ensure_ready(self) -> None:
        """Lazily run ``startup``; raise RuntimeError if the service is still not ready."""
        if self._model is None or self._tokenizer is None:
            await self.startup()
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model service failed to initialize")

    @staticmethod
    def _cancellation_criteria(cancel_event: Event) -> StoppingCriteriaList:
        """Return stopping criteria that report "done" once *cancel_event* is set.

        Lets a streaming request stop a running ``generate`` early when its
        consumer goes away.
        """

        class _StopWhenCancelled(StoppingCriteria):
            def __call__(self, input_ids, scores, **kwargs):
                # One flag per batch row, all equal to the event state.
                return torch.full(
                    (input_ids.shape[0],),
                    cancel_event.is_set(),
                    dtype=torch.bool,
                    device=input_ids.device,
                )

        return StoppingCriteriaList([_StopWhenCancelled()])

    async def stream_text(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ):
        """Yield ``(request_id, text_delta)`` tuples for SSE streaming.

        Generation runs in a background thread that feeds a
        ``TextIteratorStreamer``. ``_generation_lock`` is held until that thread
        has been joined. If the consumer stops early, the thread is told to stop
        through a stopping criterion and joined before the lock is released.

        Raises:
            RuntimeError: if the model cannot be initialized or generation fails.
        """
        await self._ensure_ready()
        request_id = f"chatcmpl-{uuid4().hex}"
        prompt = self._build_prompt(messages)
        inputs = self._tokenize_prompt(prompt)
        generation_kwargs = self._build_generation_kwargs(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        streamer = TextIteratorStreamer(
            self._tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        model = self._model
        cancel_event = Event()
        stopping_criteria = self._cancellation_criteria(cancel_event)
        generation_error: Exception | None = None

        def run_generation() -> None:
            nonlocal generation_error
            try:
                with torch.inference_mode():
                    model.generate(
                        **inputs,
                        streamer=streamer,
                        stopping_criteria=stopping_criteria,
                        **generation_kwargs,
                    )
            except Exception as exc:  # thread boundary: record, log, surface to consumer
                generation_error = exc
                logger.exception("Streaming generation failed")
                # The streamer's stop signal is only sent when generate() completes;
                # send it here too or the consumer blocks on the queue forever.
                streamer.on_finalized_text("", stream_end=True)

        async with self._generation_lock:
            worker = Thread(target=run_generation)
            worker.start()
            try:
                iterator = iter(streamer)
                while True:
                    token = await asyncio.to_thread(next, iterator, None)
                    if token is None or generation_error is not None:
                        break
                    yield request_id, token
            finally:
                # Also runs on client disconnect (aclose/cancel): stop the worker and
                # wait for it so no second generate() can overlap this one.
                cancel_event.set()
                await asyncio.to_thread(worker.join)
        if generation_error is not None:
            raise RuntimeError("Streaming generation failed") from generation_error

    async def complete_text(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ) -> tuple[str, str]:
        """Generate the final full completion in non-stream mode.

        Returns:
            ``(request_id, text)`` where *text* is the decoded continuation only,
            with special tokens skipped and surrounding whitespace stripped.

        Raises:
            RuntimeError: if the model is unavailable or tokenization yields no
                ``input_ids``.
        """
        await self._ensure_ready()
        request_id = f"chatcmpl-{uuid4().hex}"
        tokenizer = self._tokenizer
        prompt = self._build_prompt(messages)
        inputs = self._tokenize_prompt(prompt)
        input_ids = inputs.get("input_ids")
        if input_ids is None:
            # Validate before spending CPU on generation.
            raise RuntimeError("Tokenization failed to produce input_ids")
        generation_kwargs = self._build_generation_kwargs(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        async with self._generation_lock:
            output_ids = await asyncio.to_thread(
                self._generate_sync,
                inputs,
                generation_kwargs,
            )
        # The output row starts with the prompt tokens; keep only the new ones.
        prompt_token_count = int(input_ids.shape[-1])
        generated_ids = output_ids[0][prompt_token_count:]
        final_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        return request_id, final_text

    def _generate_sync(self, inputs: dict, generation_kwargs: dict):
        """Run ``model.generate`` synchronously under ``torch.inference_mode``.

        Raises:
            RuntimeError: if the model is not initialized.
        """
        model = self._model  # single read, so a concurrent shutdown cannot null it mid-call
        if model is None:
            raise RuntimeError("Model is not initialized")
        with torch.inference_mode():
            return model.generate(**inputs, **generation_kwargs)