from abc import ABC, abstractmethod
from fastapi import HTTPException
from pydantic import BaseModel
import torch
import time
import os

# Try to import vLLM; fall back to Transformers-only mode if it is missing
try:
    from vllm import LLM, SamplingParams
    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

# Import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 100
    temperature: float = 0.7
    top_p: float = 0.9
    top_k: int = 50
    repetition_penalty: float = 1.0
    do_sample: bool = True
    enable_thinking: bool = False


class GenerationResponse(BaseModel):
    generated_text: str
    input_tokens: int
    output_tokens: int
    inference_time: float
    model_name: str

class ModelManager(ABC):
    def __init__(self, model_name: str, backend_type: str):
        self.model_name = model_name
        self.backend_type = backend_type
        self.model = None
        self.tokenizer = None
        self.device = None
        self.is_loaded = False

    @abstractmethod
    def load_model(self):
        """Load the model and tokenizer"""
        pass

    @abstractmethod
    def generate_text(self, request: GenerationRequest) -> GenerationResponse:
        """Generate text using the loaded model"""
        pass

class ModelTransformersManager(ModelManager):
    def __init__(self, model_name: str):
        super().__init__(model_name, "Transformers")

    def load_model(self):
        """Load the model and tokenizer"""
        try:
            print(f"(iii) Loading model: {self.model_name}")
            # Determine device: prefer CUDA, then Apple Silicon (MPS), then CPU
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
            print(f"Using device: {self.device}")
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # Load model with appropriate settings for the device
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                low_cpu_mem_usage=True
            ).to(self.device)
            # Set pad token if it does not exist
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.is_loaded = True
            print("(iii) Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
    def generate_text(self, request: GenerationRequest) -> GenerationResponse:
        """Generate text using the loaded model"""
        if not self.is_loaded:
            raise HTTPException(status_code=500, detail="Model not loaded")
        try:
            start_time = time.time()
            # Build the chat-formatted input. add_generation_prompt=True appends
            # the assistant-turn marker so the model responds as the assistant;
            # enable_thinking is forwarded to templates that support it (e.g. Qwen3).
            messages = [{"role": "user", "content": request.prompt}]
            inputs = self.tokenizer.apply_chat_template(
                messages,
                enable_thinking=request.enable_thinking,
                add_generation_prompt=True,
                tokenize=True,
                return_tensors="pt"
            ).to(self.device)
            # Decode back to text for debugging
            text_inputs = self.tokenizer.decode(inputs[0], skip_special_tokens=False)
            print(f"(ddd) Text inputs: {text_inputs}")
            input_tokens = inputs.shape[1]
            print(f"(ddd) Input tokens: {input_tokens}")
            # Generate text
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=request.max_new_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    top_k=request.top_k,
                    repetition_penalty=request.repetition_penalty,
                    do_sample=request.do_sample,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            # Decode only the newly generated tokens. Slicing off the prompt
            # tokens is more reliable than string-matching the raw prompt,
            # which never appears verbatim once the chat template is applied.
            generated_text = self.tokenizer.decode(
                outputs[0][input_tokens:],
                skip_special_tokens=True
            ).strip()
            output_tokens = outputs.shape[1] - input_tokens
            inference_time = time.time() - start_time
            return GenerationResponse(
                generated_text=generated_text,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                inference_time=inference_time,
                model_name=self.model_name
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

class ModelVllmManager(ModelManager):
    def __init__(self, model_name: str):
        super().__init__(model_name, "VLLM")
        if not VLLM_AVAILABLE:
            raise ImportError("vLLM is not installed. Please install it with: pip install vllm")

    def load_model(self):
        """Load the model using vLLM"""
        try:
            print(f"(iii) Loading model with vLLM: {self.model_name}")
            # vLLM requires CUDA
            if not torch.cuda.is_available():
                raise RuntimeError("vLLM requires CUDA support")
            self.device = "cuda"
            print(f"Using device: {self.device}")
            # Load model with vLLM
            self.model = LLM(
                model=self.model_name,
                dtype="auto",
                trust_remote_code=True
            )
            # Load tokenizer for preprocessing
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # Set pad token if it does not exist
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.is_loaded = True
            print("(iii) Model loaded successfully with vLLM!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
    def generate_text(self, request: GenerationRequest) -> GenerationResponse:
        """Generate text using the loaded vLLM model"""
        if not self.is_loaded:
            raise HTTPException(status_code=500, detail="Model not loaded")
        try:
            start_time = time.time()
            # Prepare sampling parameters. vLLM has no do_sample flag;
            # temperature=0 yields greedy decoding instead.
            sampling_params = SamplingParams(
                max_tokens=request.max_new_tokens,
                temperature=request.temperature if request.do_sample else 0.0,
                top_p=request.top_p,
                top_k=request.top_k,
                repetition_penalty=request.repetition_penalty,
                stop_token_ids=[self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id else None
            )
            # Build the chat-formatted prompt once and reuse it for both token
            # counting and generation, so the two stay consistent.
            messages = [{"role": "user", "content": request.prompt}]
            prompt_text = self.tokenizer.apply_chat_template(
                messages,
                enable_thinking=request.enable_thinking,
                add_generation_prompt=True,
                tokenize=False
            )
            # The template text already contains the special tokens, so do not
            # add them again when counting.
            input_tokens = len(self.tokenizer(prompt_text, add_special_tokens=False).input_ids)
            print(f"(ddd) Input tokens: {input_tokens}")
            # Generate text with vLLM
            outputs = self.model.generate(
                prompt_text,
                sampling_params=sampling_params,
                use_tqdm=False
            )
            # Extract generated text and count output tokens
            if outputs and len(outputs) > 0:
                generated_text = outputs[0].outputs[0].text
                output_tokens = len(outputs[0].outputs[0].token_ids)
            else:
                generated_text = ""
                output_tokens = 0
            inference_time = time.time() - start_time
            return GenerationResponse(
                generated_text=generated_text,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                inference_time=inference_time,
                model_name=self.model_name
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

def create_model_manager(model_name: str, backend_type: str = "Transformers") -> ModelManager:
    """Factory function to create the appropriate model manager"""
    if backend_type.upper() == "VLLM":
        return ModelVllmManager(model_name)
    else:
        return ModelTransformersManager(model_name)
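
# Minimal usage sketch (an addition, not part of the original service wiring):
# it assumes a chat model id such as "Qwen/Qwen3-0.6B" -- substitute any model
# you actually have -- and exercises the factory plus one generation round trip.
if __name__ == "__main__":
    manager = create_model_manager("Qwen/Qwen3-0.6B", backend_type="Transformers")
    manager.load_model()
    result = manager.generate_text(
        GenerationRequest(prompt="Explain vLLM in one sentence.", max_new_tokens=64)
    )
    print(result)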