Spaces:
Running
Running
| """Transformers + PyTorch text generation service for KORA.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| from threading import Thread | |
| from uuid import uuid4 | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from app.utils.config import Settings | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Sampling threshold: a requested temperature below this value is decoded
# greedily instead of sampled (see ModelService._build_generation_kwargs).
MIN_TEMPERATURE = 1e-5
class ModelService:
    """Manage a PEFT-adapted causal LM and run CPU text generation.

    The tokenizer and base weights are loaded from ``BASE_MODEL_NAME``; the
    adapter named by ``settings.model_name`` is applied on top of them.
    Generation calls are serialized through an asyncio lock so that only one
    ``generate()`` call issued by this service runs at a time.
    """

    # Checkpoint used for both the tokenizer and the base weights. Kept in one
    # place so the two can never drift apart.
    BASE_MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

    def __init__(self, settings: "Settings") -> None:
        """Store the settings and create the locks; nothing is loaded here."""
        self.settings = settings
        self._model = None
        self._tokenizer = None
        self._startup_lock = asyncio.Lock()
        self._generation_lock = asyncio.Lock()

    def _is_ready(self) -> bool:
        """Return True when both the model and the tokenizer are loaded."""
        return self._model is not None and self._tokenizer is not None

    async def startup(self) -> None:
        """Initialize model engine once per process.

        Safe to call repeatedly and concurrently: the readiness check is
        repeated under ``_startup_lock`` so only one caller performs the load.
        """
        if self._is_ready():
            return
        async with self._startup_lock:
            if self._is_ready():
                return
            self._configure_torch_threads()
            logger.info("Loading model via Transformers on CPU: %s", self.settings.model_name)
            # Loading reads and deserializes the weights, which blocks for a
            # long time; run it in a worker thread so the event loop stays
            # responsive.
            tokenizer, model = await asyncio.to_thread(self._load_sync)
            # Publish both together, and only after everything succeeded, so a
            # failed load never leaves a half-initialized service behind.
            self._tokenizer = tokenizer
            self._model = model
            logger.info("CPU model and tokenizer initialized")

    def _configure_torch_threads(self) -> None:
        """Apply the configured torch thread counts (values <= 0 are skipped)."""
        if self.settings.torch_num_threads > 0:
            torch.set_num_threads(self.settings.torch_num_threads)
        if self.settings.torch_num_interop_threads > 0:
            try:
                torch.set_num_interop_threads(self.settings.torch_num_interop_threads)
            except RuntimeError:
                # torch only allows this to be set once, before inter-op work
                # has started; a restart after shutdown() must not crash.
                logger.warning(
                    "Could not set torch inter-op thread count; keeping the current value",
                    exc_info=True,
                )

    def _load_sync(self):
        """Load tokenizer, base model and PEFT adapter; return ``(tokenizer, model)``.

        Blocking; intended to be run in a worker thread.
        """
        # Imported lazily, as in the original code, so importing this module
        # does not require peft.
        from peft import PeftModel

        tokenizer = AutoTokenizer.from_pretrained(
            self.BASE_MODEL_NAME,
            trust_remote_code=self.settings.trust_remote_code,
            use_fast=True,
        )
        # generate() is given a pad token id; fall back to EOS when the
        # tokenizer defines none.
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.BASE_MODEL_NAME,
            trust_remote_code=self.settings.trust_remote_code,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
        logger.info("Applying PEFT adapter...")
        model = PeftModel.from_pretrained(base_model, self.settings.model_name)
        model.eval()
        return tokenizer, model

    async def shutdown(self) -> None:
        """Graceful shutdown hook: drop the references to model and tokenizer."""
        self._model = None
        self._tokenizer = None

    def _build_prompt(self, messages: list[dict[str, str]]) -> str:
        """Render OpenAI-style messages into a model prompt.

        Uses the tokenizer's chat template; if that raises ``AttributeError``,
        ``TypeError`` or ``ValueError`` a plain ``ROLE: content`` transcript
        ending in ``ASSISTANT:`` is produced instead.

        Raises:
            RuntimeError: if the tokenizer has not been loaded.
        """
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer is not initialized")
        try:
            return self._tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except (AttributeError, TypeError, ValueError):
            logger.warning("Chat template not supported, using fallback format", exc_info=True)
        # Fallback for tokenizer templates that may not support message format.
        lines = [f"{m['role'].upper()}: {m['content']}" for m in messages]
        lines.append("ASSISTANT:")
        return "\n".join(lines)

    def _build_generation_kwargs(
        self,
        *,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ) -> dict:
        """Translate request parameters into ``generate()`` keyword arguments.

        Negative temperatures are clamped to 0 and ``top_p`` to ``[0, 1]``.
        Temperatures below ``MIN_TEMPERATURE`` select greedy decoding, in
        which case no sampling parameters are passed at all. At least one new
        token is always requested.

        Raises:
            RuntimeError: if the tokenizer has not been loaded.
        """
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer is not initialized")
        clamped_temperature = max(0.0, float(temperature))
        clamped_top_p = min(max(float(top_p), 0.0), 1.0)
        do_sample = clamped_temperature >= MIN_TEMPERATURE
        kwargs = {
            "max_new_tokens": max(1, int(max_tokens)),
            "do_sample": do_sample,
            "pad_token_id": self._tokenizer.pad_token_id,
            "eos_token_id": self._tokenizer.eos_token_id,
        }
        if do_sample:
            # do_sample already guarantees temperature >= MIN_TEMPERATURE.
            kwargs["temperature"] = clamped_temperature
            # A top_p of 0 would leave nothing to sample from; treat it as
            # "no nucleus filtering".
            kwargs["top_p"] = clamped_top_p if clamped_top_p > 0.0 else 1.0
        return kwargs

    def _tokenize_prompt(self, prompt: str) -> dict:
        """Tokenize ``prompt`` to PyTorch tensors, truncated to the input limit.

        Raises:
            RuntimeError: if the tokenizer has not been loaded.
        """
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer is not initialized")
        # NOTE(review): truncation uses the tokenizer's configured side; if
        # that is the right-hand side, over-long chats lose their most recent
        # turns - confirm this is intended.
        return self._tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.settings.max_input_tokens,
        )

    async def _ensure_ready(self) -> None:
        """Load the model on demand.

        Raises:
            RuntimeError: if the service is still not ready after startup.
        """
        if not self._is_ready():
            await self.startup()
        if not self._is_ready():
            raise RuntimeError("Model service failed to initialize")

    async def stream_text(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ):
        """Yield incremental token deltas for SSE streaming.

        Yields ``(request_id, text_chunk)`` tuples while ``generate()`` runs
        in a worker thread.

        Raises:
            RuntimeError: if the service cannot be initialized, or if
                generation fails (the original exception is chained).
        """
        await self._ensure_ready()
        # Keep local references so a concurrent shutdown() cannot turn them
        # into None in the middle of this request.
        model = self._model
        tokenizer = self._tokenizer

        request_id = f"chatcmpl-{uuid4().hex}"
        prompt = self._build_prompt(messages)
        inputs = self._tokenize_prompt(prompt)
        generation_kwargs = self._build_generation_kwargs(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generation_error: Exception | None = None

        def run_generation() -> None:
            nonlocal generation_error
            try:
                with torch.inference_mode():
                    model.generate(
                        **inputs,
                        streamer=streamer,
                        **generation_kwargs,
                    )
            except Exception as exc:  # pragma: no cover - runtime guard
                # Thread boundary: record every failure so the consumer can
                # re-raise it instead of the thread dying silently.
                generation_error = exc
                logger.exception("Streaming generation failed")
                # generate() only signals end-of-stream when it succeeds.
                # Without this the consumer would block forever on the
                # streamer queue while holding the generation lock.
                streamer.end()

        async with self._generation_lock:
            worker = Thread(target=run_generation)
            worker.start()
            iterator = iter(streamer)
            # NOTE(review): if the consumer stops iterating early, nothing
            # here cancels the worker; it keeps generating until generate()
            # returns.
            while True:
                # next() blocks on the streamer queue, so wait in a thread.
                # The None default marks exhaustion without raising
                # StopIteration across the thread boundary.
                token = await asyncio.to_thread(next, iterator, None)
                if token is None:
                    break
                if generation_error is not None:
                    break
                yield request_id, token
            await asyncio.to_thread(worker.join)
        if generation_error is not None:
            raise RuntimeError("Streaming generation failed") from generation_error

    async def complete_text(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ) -> tuple[str, str]:
        """Generate the final full completion in non-stream mode.

        Returns:
            ``(request_id, text)`` where ``text`` is the decoded completion
            without the prompt, stripped of surrounding whitespace.

        Raises:
            RuntimeError: if the service cannot be initialized or the
                tokenizer output contains no ``input_ids``.
        """
        await self._ensure_ready()
        # Local reference: decoding happens after an await, when a concurrent
        # shutdown() may already have cleared the attribute.
        tokenizer = self._tokenizer

        request_id = f"chatcmpl-{uuid4().hex}"
        prompt = self._build_prompt(messages)
        inputs = self._tokenize_prompt(prompt)
        # Validate before generating so a bad tokenization fails fast instead
        # of after an expensive generate() call.
        input_ids = inputs.get("input_ids")
        if input_ids is None:
            raise RuntimeError("Tokenization failed to produce input_ids")
        generation_kwargs = self._build_generation_kwargs(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        async with self._generation_lock:
            output_ids = await asyncio.to_thread(
                self._generate_sync,
                inputs,
                generation_kwargs,
            )
        # The output sequence starts with the prompt tokens; keep only what
        # follows them.
        prompt_token_count = int(input_ids.shape[-1])
        generated_ids = output_ids[0][prompt_token_count:]
        final_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        return request_id, final_text

    def _generate_sync(self, inputs: dict, generation_kwargs: dict):
        """Run a blocking ``generate()`` call and return its output ids.

        Raises:
            RuntimeError: if the model has not been loaded.
        """
        if self._model is None:
            raise RuntimeError("Model is not initialized")
        with torch.inference_mode():
            return self._model.generate(**inputs, **generation_kwargs)