from abc import ABC, abstractmethod
from fastapi import HTTPException
from pydantic import BaseModel
import torch
import time

# Try to import vLLM (optional backend)
try:
    from vllm import LLM, SamplingParams
    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

# Import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 100
    temperature: float = 0.7
    top_p: float = 0.9
    top_k: int = 50
    repetition_penalty: float = 1.0
    do_sample: bool = True
    enable_thinking: bool = False

class GenerationResponse(BaseModel):
    generated_text: str
    input_tokens: int
    output_tokens: int
    inference_time: float
    model_name: str
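
# Illustrative only: the kind of JSON body a client might POST to the endpoint that
# wraps GenerationRequest, and the matching GenerationResponse shape. Values are
# made-up examples, and the endpoint itself is defined outside this module.
#
#   request:  {"prompt": "Write a haiku about the sea", "max_new_tokens": 64,
#              "temperature": 0.7, "do_sample": true, "enable_thinking": false}
#   response: {"generated_text": "...", "input_tokens": 12, "output_tokens": 58,
#              "inference_time": 0.84, "model_name": "HuggingFaceTB/SmolLM3-3B"}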

class ModelManager(ABC):
    def __init__(self, model_name: str, backend_type: str):
        self.model_name = model_name
        self.backend_type = backend_type
        self.model = None
        self.tokenizer = None
        self.device = None
        self.is_loaded = False

    @abstractmethod
    def load_model(self):
        """Load the model and tokenizer"""
        pass

    @abstractmethod
    def generate_text(self, request: GenerationRequest) -> GenerationResponse:
        """Generate text using the loaded model"""
        pass
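
# A new backend plugs in by subclassing ModelManager and implementing both abstract
# methods. The sketch below is hypothetical (this module ships no such backend); it
# only illustrates the shape a concrete manager has to take.
#
#   class ModelMyBackendManager(ModelManager):
#       def __init__(self, model_name: str):
#           super().__init__(model_name, "MyBackend")
#
#       def load_model(self):
#           ...  # load weights/tokenizer, set self.device, then self.is_loaded = True
#
#       def generate_text(self, request: GenerationRequest) -> GenerationResponse:
#           ...  # run inference and return a populated GenerationResponse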

class ModelTransformersManager(ModelManager):
    def __init__(self, model_name: str):
        super().__init__(model_name, "Transformers")

    def load_model(self):
        """Load the model and tokenizer"""
        try:
            print(f"(iii) Loading model: {self.model_name}")
            # Determine device
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
            print(f"Using device: {self.device}")
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # Load model with appropriate settings for the device
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                low_cpu_mem_usage=True
            ).to(self.device)
            # Set pad token if it is not defined
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.is_loaded = True
            print("(iii) Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")

    def generate_text(self, request: GenerationRequest) -> GenerationResponse:
        """Generate text using the loaded model"""
        if not self.is_loaded:
            raise HTTPException(status_code=500, detail="Model not loaded")
        try:
            start_time = time.time()
            # Build the chat-formatted input and tokenize it
            messages = [{"role": "user", "content": request.prompt}]
            inputs = self.tokenizer.apply_chat_template(
                messages,
                enable_thinking=request.enable_thinking,
                add_generation_prompt=False,
                tokenize=True,
                return_tensors="pt"
            ).to(self.device)
            # Decode back to text (debug output)
            text_inputs = self.tokenizer.decode(inputs[0], skip_special_tokens=False)
            print(f"(ddd) Text inputs: {text_inputs}")
            input_tokens = inputs.shape[1]
            print(f"(ddd) Input tokens: {input_tokens}")
            # Generate text
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=request.max_new_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    top_k=request.top_k,
                    repetition_penalty=request.repetition_penalty,
                    do_sample=request.do_sample,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            # Decode only the newly generated tokens, so the chat-formatted prompt
            # is never echoed back in the response
            generated_text = self.tokenizer.decode(
                outputs[0][input_tokens:],
                skip_special_tokens=True
            ).strip()
            output_tokens = outputs.shape[1] - input_tokens
            inference_time = time.time() - start_time
            return GenerationResponse(
                generated_text=generated_text,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                inference_time=inference_time,
                model_name=self.model_name
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

class ModelVllmManager(ModelManager):
    def __init__(self, model_name: str):
        super().__init__(model_name, "VLLM")
        if not VLLM_AVAILABLE:
            raise ImportError("vLLM is not installed. Please install it with: pip install vllm")

    def load_model(self):
        """Load the model using vLLM"""
        try:
            print(f"(iii) Loading model with vLLM: {self.model_name}")
            # vLLM requires CUDA
            if not torch.cuda.is_available():
                raise RuntimeError("vLLM requires CUDA support")
            self.device = "cuda"
            print(f"Using device: {self.device}")
            # Load model with vLLM
            self.model = LLM(
                model=self.model_name,
                dtype="auto",
                trust_remote_code=True
            )
            # Load tokenizer for preprocessing
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # Set pad token if it is not defined
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.is_loaded = True
            print("(iii) Model loaded successfully with vLLM!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")

    def generate_text(self, request: GenerationRequest) -> GenerationResponse:
        """Generate text using the loaded vLLM model"""
        if not self.is_loaded:
            raise HTTPException(status_code=500, detail="Model not loaded")
        try:
            start_time = time.time()
            # Prepare sampling parameters (temperature 0.0 means greedy decoding in vLLM,
            # mirroring do_sample=False in the Transformers backend)
            sampling_params = SamplingParams(
                max_tokens=request.max_new_tokens,
                temperature=request.temperature if request.do_sample else 0.0,
                top_p=request.top_p,
                top_k=request.top_k,
                repetition_penalty=request.repetition_penalty,
                stop_token_ids=[self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else None
            )
            # Build the same chat-formatted prompt as the Transformers backend, so the
            # input token count matches the text that is actually generated from
            messages = [{"role": "user", "content": request.prompt}]
            prompt_text = self.tokenizer.apply_chat_template(
                messages,
                enable_thinking=request.enable_thinking,
                add_generation_prompt=False,
                tokenize=False
            )
            input_tokens = len(self.tokenizer.encode(prompt_text, add_special_tokens=False))
            print(f"(ddd) Input tokens: {input_tokens}")
            # Generate text with vLLM
            outputs = self.model.generate(
                prompt_text,
                sampling_params=sampling_params,
                use_tqdm=False
            )
            # Extract generated text
            if outputs and len(outputs) > 0:
                generated_text = outputs[0].outputs[0].text
            else:
                generated_text = ""
            # Count output tokens
            output_tokens = len(outputs[0].outputs[0].token_ids) if outputs and len(outputs) > 0 else 0
            inference_time = time.time() - start_time
            return GenerationResponse(
                generated_text=generated_text,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                inference_time=inference_time,
                model_name=self.model_name
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
def create_model_manager(model_name: str, backend_type: str = "Transformers") -> ModelManager:
"""Factory function to create the appropriate model manager"""
if backend_type.upper() == "VLLM":
return ModelVllmManager(model_name)
else:
return ModelTransformersManager(model_name)
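
# Minimal usage sketch, guarded so it only runs when this module is executed directly.
# The model id is just an example (any causal LM on the Hub should work); loading it
# downloads the weights, so this is meant for local experimentation only.
if __name__ == "__main__":
    manager = create_model_manager("HuggingFaceTB/SmolLM3-3B", backend_type="Transformers")
    manager.load_model()
    demo_request = GenerationRequest(
        prompt="Explain what a REST API is in one sentence.",
        max_new_tokens=64
    )
    response = manager.generate_text(demo_request)
    print(response.generated_text)
    print(f"{response.output_tokens} tokens generated in {response.inference_time:.2f}s")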