Spaces:
Sleeping
Sleeping
| # model.py - Optimized version | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
| from functools import lru_cache | |
| import os | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| import logging | |
| logger = logging.getLogger(__name__) | |
# Module-level cache for the loaded model/tokenizer (populated lazily).
_tokenizer = None  # set by load_model_sync() on first successful load
_model = None  # set by load_model_sync() on first successful load
_model_loading = False  # True while load_model_async() has a load in flight
_model_loaded = False  # True once both tokenizer and model are ready
@lru_cache(maxsize=1)
def get_model_config():
    """Return the model configuration, built once and cached.

    The docstring/comment promised caching but the dict was rebuilt on
    every call; ``lru_cache`` (already imported) now makes it a true
    singleton. The returned dict is shared — callers must treat it as
    read-only.

    Returns:
        dict: keyword options consumed by ``from_pretrained`` plus the
        ``model_id`` to load.
    """
    return {
        "model_id": "deepseek-ai/deepseek-coder-1.3b-instruct",
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
        # Memory/startup optimizations for from_pretrained()
        "low_cpu_mem_usage": True,
        "use_cache": True,
    }
def load_model_sync():
    """Load tokenizer and model synchronously; idempotent.

    Results are cached in the module globals ``_tokenizer`` / ``_model``,
    so repeated calls after a successful load return immediately.

    Returns:
        tuple: ``(tokenizer, model)``.

    Raises:
        Exception: whatever ``from_pretrained`` raises on failure,
        re-raised after logging.
    """
    global _tokenizer, _model, _model_loaded
    # Fast path: a previous call already finished loading.
    if _model_loaded:
        return _tokenizer, _model
    config = get_model_config()
    model_id = config["model_id"]
    # Lazy %-style args avoid formatting when the level is disabled;
    # original messages carried mojibake emoji bytes, replaced with text.
    logger.info("Loading model %s ...", model_id)
    try:
        # Reuse a local cache directory to avoid re-downloading weights.
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
        os.makedirs(cache_dir, exist_ok=True)
        # Load the tokenizer first (much faster than the weights).
        logger.info("Loading tokenizer...")
        _tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
            use_fast=True,  # use the fast (Rust) tokenizer if available
        )
        # Load the weights with memory-friendly options (low_cpu_mem_usage,
        # disk offload for layers that don't fit the device map).
        logger.info("Loading model weights...")
        _model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            torch_dtype=config["torch_dtype"],
            device_map=config["device_map"],
            low_cpu_mem_usage=config["low_cpu_mem_usage"],
            cache_dir=cache_dir,
            offload_folder="offload",
            offload_state_dict=True,
        )
        # Inference only: disables dropout etc.
        _model.eval()
        _model_loaded = True
        logger.info("Model loaded successfully!")
        return _tokenizer, _model
    except Exception as e:
        logger.error("Failed to load model: %s", e)
        raise
async def load_model_async():
    """Load the model without blocking the running event loop.

    Concurrent callers are coalesced: the first caller performs the load
    in a worker thread, later callers poll ``_model_loading`` until it
    completes and then return the shared tokenizer/model pair.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    global _model_loading
    if _model_loaded:
        return _tokenizer, _model
    if _model_loading:
        # Another coroutine is already loading; wait for it to finish.
        while _model_loading and not _model_loaded:
            await asyncio.sleep(0.1)
        return _tokenizer, _model
    _model_loading = True
    try:
        # get_running_loop() is the correct call inside a coroutine;
        # asyncio.get_event_loop() is deprecated here since Python 3.10.
        loop = asyncio.get_running_loop()
        # Run the blocking load in a worker thread so the loop stays live.
        with ThreadPoolExecutor(max_workers=1) as executor:
            tokenizer, model = await loop.run_in_executor(
                executor, load_model_sync
            )
        return tokenizer, model
    finally:
        # Always clear the flag so waiters (and retries on failure) proceed.
        _model_loading = False
def get_model():
    """Return ``(tokenizer, model)``, loading them synchronously on first use."""
    if _model_loaded:
        # Already resident — hand back the cached pair.
        return _tokenizer, _model
    return load_model_sync()
def is_model_loaded():
    """Report whether the model has finished loading."""
    return bool(_model_loaded)
def get_model_info():
    """Describe the configured model and its load state without loading it."""
    model_id = get_model_config()["model_id"]
    return {
        "model_id": model_id,
        "loaded": _model_loaded,
        "loading": _model_loading,
    }