Spaces:
Running
Running
File size: 8,174 Bytes
"""Transformers + PyTorch text generation service for KORA."""
from __future__ import annotations
import asyncio
import logging
from threading import Thread
from uuid import uuid4
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from app.utils.config import Settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Temperature threshold for sampling: requests whose clamped temperature is
# below this value are generated with do_sample=False.
MIN_TEMPERATURE = 1e-5
class ModelService:
    """Manage the Transformers model lifecycle and serialized CPU generation.

    The service lazily loads a base causal LM plus a PEFT adapter, renders
    OpenAI-style chat messages into a prompt, and exposes both a streaming
    and a non-streaming generation entry point. All generation calls are
    serialized through a single asyncio lock.
    """

    # Base checkpoint the PEFT adapter is applied on top of. Can be overridden
    # by an optional ``base_model_name`` attribute on the settings object.
    DEFAULT_BASE_MODEL = "microsoft/Phi-3-mini-4k-instruct"

    def __init__(self, settings: Settings) -> None:
        """Store settings and create the (not yet loaded) service state.

        Args:
            settings: Application settings; this class reads ``model_name``,
                ``trust_remote_code``, ``torch_num_threads``,
                ``torch_num_interop_threads`` and ``max_input_tokens``.
        """
        self.settings = settings
        self._model = None
        self._tokenizer = None
        # Guards one-time model loading.
        self._startup_lock = asyncio.Lock()
        # Serializes generation so only one request runs the model at a time.
        self._generation_lock = asyncio.Lock()

    async def startup(self) -> None:
        """Load tokenizer, base model and PEFT adapter if not already loaded.

        Safe to call repeatedly and concurrently: the loaded state is checked
        both before and after acquiring the startup lock.
        """
        if self._model is not None and self._tokenizer is not None:
            return
        async with self._startup_lock:
            if self._model is not None and self._tokenizer is not None:
                return

            self._configure_torch_threads()
            base_model_name = (
                getattr(self.settings, "base_model_name", None)
                or self.DEFAULT_BASE_MODEL
            )

            logger.info(
                "Loading model via Transformers on CPU: %s",
                self.settings.model_name,
            )
            tokenizer = AutoTokenizer.from_pretrained(
                base_model_name,
                trust_remote_code=self.settings.trust_remote_code,
                use_fast=True,
            )
            # generate() needs a pad token id; fall back to EOS when unset.
            if tokenizer.pad_token is None and tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token

            # Imported lazily so the module can be imported without peft.
            from peft import PeftModel

            logger.info("Loading base model...")
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                trust_remote_code=self.settings.trust_remote_code,
                torch_dtype=torch.float32,
                device_map="cpu",
            )
            logger.info("Applying PEFT adapter...")
            model = PeftModel.from_pretrained(base_model, self.settings.model_name)
            model.eval()

            # Publish both objects only after every load step succeeded, so a
            # failed load never leaves the service half-initialized.
            self._tokenizer = tokenizer
            self._model = model
            logger.info("CPU model and tokenizer initialized")

    def _configure_torch_threads(self) -> None:
        """Apply the configured torch thread counts (values <= 0 are skipped)."""
        if self.settings.torch_num_threads > 0:
            torch.set_num_threads(self.settings.torch_num_threads)
        if self.settings.torch_num_interop_threads > 0:
            try:
                torch.set_num_interop_threads(
                    self.settings.torch_num_interop_threads
                )
            except RuntimeError:
                # PyTorch allows this call only once per process and only
                # before inter-op work has started; a second startup() (for
                # example after shutdown()) must not fail because of it.
                logger.warning(
                    "Could not set torch inter-op threads; keeping current value",
                    exc_info=True,
                )

    async def shutdown(self) -> None:
        """Drop the references to the loaded model and tokenizer."""
        self._model = None
        self._tokenizer = None

    def _build_prompt(self, messages: list[dict[str, str]]) -> str:
        """Render OpenAI-style messages into a single prompt string.

        Uses the tokenizer's chat template and falls back to a plain
        ``ROLE: content`` transcript ending in ``ASSISTANT:`` when the
        template call fails.

        Raises:
            RuntimeError: If the tokenizer has not been loaded.
        """
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer is not initialized")
        try:
            return self._tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except (AttributeError, TypeError, ValueError):
            logger.warning(
                "Chat template not supported, using fallback format",
                exc_info=True,
            )
            # Fallback for tokenizer templates that may not support message format.
            lines = [f"{m['role'].upper()}: {m['content']}" for m in messages]
            lines.append("ASSISTANT:")
            return "\n".join(lines)

    def _build_generation_kwargs(
        self,
        *,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ) -> dict:
        """Translate request sampling parameters into ``generate()`` kwargs.

        Temperature is clamped to be non-negative and ``top_p`` to [0, 1].
        Sampling is enabled only when the clamped temperature reaches
        ``MIN_TEMPERATURE``; a ``top_p`` of 0 is replaced by 1.0, and
        ``max_tokens`` is raised to at least 1.

        Raises:
            RuntimeError: If the tokenizer has not been loaded.
        """
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer is not initialized")
        clamped_temperature = max(0.0, float(temperature))
        clamped_top_p = min(max(float(top_p), 0.0), 1.0)
        do_sample = clamped_temperature >= MIN_TEMPERATURE
        kwargs = {
            "max_new_tokens": max(1, int(max_tokens)),
            "do_sample": do_sample,
            "pad_token_id": self._tokenizer.pad_token_id,
            "eos_token_id": self._tokenizer.eos_token_id,
        }
        if do_sample:
            # do_sample already guarantees temperature >= MIN_TEMPERATURE.
            kwargs["temperature"] = clamped_temperature
            kwargs["top_p"] = clamped_top_p if clamped_top_p > 0.0 else 1.0
        return kwargs

    def _tokenize_prompt(self, prompt: str) -> dict:
        """Tokenize a rendered prompt into PyTorch tensors.

        The prompt is truncated to ``settings.max_input_tokens``. When the
        prompt text already begins with the tokenizer's BOS token (as chat
        templates commonly emit), special tokens are not added again so the
        model does not see a duplicated BOS.

        Raises:
            RuntimeError: If the tokenizer has not been loaded.
        """
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer is not initialized")
        bos_token = getattr(self._tokenizer, "bos_token", None)
        already_has_bos = bool(bos_token) and prompt.startswith(bos_token)
        # NOTE(review): truncation uses the tokenizer's configured side; if it
        # truncates on the right, an over-long prompt loses its trailing
        # generation prompt - confirm the intended truncation side.
        return self._tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.settings.max_input_tokens,
            add_special_tokens=not already_has_bos,
        )

    async def _ensure_ready(self) -> None:
        """Load the model on demand and fail loudly if it is still missing.

        Raises:
            RuntimeError: If the model or tokenizer is unavailable after
                ``startup()`` returned.
        """
        if self._model is None or self._tokenizer is None:
            await self.startup()
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model service failed to initialize")

    async def stream_text(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ):
        """Yield incremental text deltas for SSE streaming.

        Generation runs in a worker thread feeding a ``TextIteratorStreamer``;
        this coroutine drains the streamer without blocking the event loop.

        Yields:
            ``(request_id, text_delta)`` tuples, where ``request_id`` is the
            same ``chatcmpl-...`` identifier for every delta of one call.

        Raises:
            RuntimeError: If the service is not initialized or generation
                fails; the original error is chained as the cause.
        """
        await self._ensure_ready()
        request_id = f"chatcmpl-{uuid4().hex}"
        prompt = self._build_prompt(messages)
        inputs = self._tokenize_prompt(prompt)
        generation_kwargs = self._build_generation_kwargs(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        generation_error: Exception | None = None

        async with self._generation_lock:
            # Capture the model once so a concurrent shutdown() cannot turn
            # the worker's attribute lookup into a None dereference.
            model = self._model
            if model is None:
                raise RuntimeError("Model is not initialized")
            streamer = TextIteratorStreamer(
                self._tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )

            def run_generation() -> None:
                nonlocal generation_error
                try:
                    with torch.inference_mode():
                        model.generate(
                            **inputs,
                            streamer=streamer,
                            **generation_kwargs,
                        )
                except Exception as exc:  # pragma: no cover - runtime guard
                    # Thread boundary: record every failure so the consumer
                    # can re-raise it instead of the thread dying silently.
                    generation_error = exc
                    logger.exception("Streaming generation failed")
                    # generate() only signals end-of-stream on success; push
                    # the stop signal ourselves or the consumer blocks forever.
                    streamer.on_finalized_text("", stream_end=True)

            worker = Thread(target=run_generation)
            worker.start()
            try:
                iterator = iter(streamer)
                while True:
                    token = await asyncio.to_thread(next, iterator, None)
                    if token is None or generation_error is not None:
                        break
                    yield request_id, token
            finally:
                # Also runs when the consumer stops early (disconnect or
                # cancellation): keep holding the lock until the worker is
                # done so two generations never run at the same time.
                await asyncio.to_thread(worker.join)

        if generation_error is not None:
            raise RuntimeError("Streaming generation failed") from generation_error

    async def complete_text(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ) -> tuple[str, str]:
        """Generate the full completion in non-stream mode.

        Returns:
            ``(request_id, text)`` where ``text`` is the decoded completion
            without the prompt tokens, stripped of surrounding whitespace.

        Raises:
            RuntimeError: If the service is not initialized or tokenization
                produced no ``input_ids``.
        """
        await self._ensure_ready()
        request_id = f"chatcmpl-{uuid4().hex}"
        prompt = self._build_prompt(messages)
        inputs = self._tokenize_prompt(prompt)

        # Validate before generating so a bad encoding fails fast instead of
        # after an expensive generate() call.
        input_ids = inputs.get("input_ids")
        if input_ids is None:
            raise RuntimeError("Tokenization failed to produce input_ids")
        prompt_token_count = int(input_ids.shape[-1])

        generation_kwargs = self._build_generation_kwargs(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        # Keep a local reference so a concurrent shutdown() cannot null the
        # tokenizer between generation and decoding.
        tokenizer = self._tokenizer

        async with self._generation_lock:
            output_ids = await asyncio.to_thread(
                self._generate_sync,
                inputs,
                generation_kwargs,
            )

        # generate() returns prompt + completion; keep only the new tokens.
        generated_ids = output_ids[0][prompt_token_count:]
        final_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        return request_id, final_text

    def _generate_sync(self, inputs: dict, generation_kwargs: dict):
        """Run a blocking ``generate()`` call under ``torch.inference_mode``.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self._model is None:
            raise RuntimeError("Model is not initialized")
        with torch.inference_mode():
            return self._model.generate(**inputs, **generation_kwargs)
|