import logging
import os

import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM

from .config import Config

logger = logging.getLogger(__name__)


class ModelManager:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Ensure offline mode is disabled
        os.environ['HF_HUB_OFFLINE'] = '0'
        os.environ['TRANSFORMERS_OFFLINE'] = '0'

        # Login to Hugging Face Hub
        if Config.HUGGING_FACE_TOKEN:
            logger.info("Logging in to Hugging Face Hub")
            try:
                login(token=Config.HUGGING_FACE_TOKEN, add_to_git_credential=False)
                logger.info("Successfully logged in to Hugging Face Hub")
            except Exception as e:
                logger.error(f"Failed to login to Hugging Face Hub: {str(e)}")
                raise

        # Initialize tokenizer and model
        self._init_tokenizer()
        self._init_model()
    def _init_tokenizer(self):
        """Initialize the tokenizer."""
        try:
            logger.info(f"Loading tokenizer: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=Config.HUGGING_FACE_TOKEN,
                model_max_length=1024,  # Limit max length to save memory
                trust_remote_code=True
            )

            # Add only the special tokens the tokenizer is missing, so we
            # don't override the model's native pad/eos/bos tokens.
            special_tokens = {}
            if self.tokenizer.pad_token is None:
                special_tokens['pad_token'] = '[PAD]'
            if self.tokenizer.eos_token is None:
                special_tokens['eos_token'] = '</s>'
            if self.tokenizer.bos_token is None:
                special_tokens['bos_token'] = '<s>'
            if special_tokens:
                self.tokenizer.add_special_tokens(special_tokens)

            logger.info("Tokenizer loaded successfully")
            logger.debug(f"Tokenizer vocabulary size: {len(self.tokenizer)}")
        except Exception as e:
            logger.error(f"Error loading tokenizer: {str(e)}")
            raise
    def _init_model(self):
        """Initialize the model."""
        try:
            logger.info(f"Loading model: {self.model_name}")
            logger.info(f"Using device: {self.device}")

            # Load model with memory optimizations
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map={"": self.device},
                torch_dtype=torch.float32,
                token=Config.HUGGING_FACE_TOKEN,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )

            # Resize embeddings to match the tokenizer vocabulary
            # (covers any special tokens added above)
            self.model.resize_token_embeddings(len(self.tokenizer))

            logger.info("Model loaded successfully")
            logger.debug(f"Model parameters: {sum(p.numel() for p in self.model.parameters())}")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise
    def generate_text(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Generate text from a prompt."""
        try:
            logger.info("Starting text generation")
            logger.debug(f"Prompt length: {len(prompt)}")

            # Encode the prompt with reduced max length
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,  # Reduced max length
                padding=True
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            logger.debug(f"Input tensor shape: {inputs['input_ids'].shape}")

            # Generate response with memory optimizations. Note that
            # early_stopping is omitted: it only applies to beam search
            # and has no effect (and warns) with num_beams=1.
            logger.info("Generating response")
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=Config.TEMPERATURE,
                    top_p=Config.TOP_P,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    num_beams=1,  # Disable beam search to save memory
                    use_cache=True  # Enable KV cache for faster generation
                )

            # Clear CUDA cache after generation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Decode only the newly generated tokens. Slicing the decoded
            # string by len(prompt) is unreliable, since tokenization does
            # not round-trip the prompt text exactly.
            prompt_length = inputs['input_ids'].shape[1]
            response = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()

            logger.info("Text generation completed")
            logger.debug(f"Response length: {len(response)}")
            return response
        except Exception as e:
            logger.error(f"Error generating text: {str(e)}")
            logger.error(f"Error details: {type(e).__name__}")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise
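

# Minimal usage sketch, for illustration only: it assumes a Config.MODEL_NAME
# attribute in .config naming the checkpoint to load (hypothetical here; only
# HUGGING_FACE_TOKEN, TEMPERATURE, and TOP_P are referenced above).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = ModelManager(Config.MODEL_NAME)  # Config.MODEL_NAME is assumed
    print(manager.generate_text("Explain what a tokenizer does.", max_new_tokens=64))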