Spaces:
Running on Zero
Running on Zero
File size: 2,380 Bytes
a067ada fab3ba1 a067ada fab3ba1 a067ada | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | """
Base model class defining the interface for all specialized models.
All model implementations inherit from BaseModel and implement
the abstract methods for loading and generating outputs.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import logging
logger = logging.getLogger(__name__)


class BaseModel(ABC):
    """Abstract base class for all model implementations.

    Subclasses must implement :meth:`load` and :meth:`generate`. The base
    class tracks the model/tokenizer handles and a loaded flag, and offers
    shared helpers for unloading and for checking the loaded state.
    """

    def __init__(self, model_name: str, model_path: Optional[str] = None) -> None:
        """
        Initialize base model.

        Nothing is loaded here; ``model`` and ``tokenizer`` start as ``None``
        and ``is_loaded`` starts as ``False``.

        Args:
            model_name: Name/identifier of the model
            model_path: Path to model weights or config
        """
        self.model_name = model_name
        self.model_path = model_path
        self.is_loaded: bool = False
        self.model: Any = None
        self.tokenizer: Any = None

    @abstractmethod
    def load(self) -> None:
        """
        Load the model and initialize it for inference.

        Must be implemented by subclasses. Should set self.model
        and update self.is_loaded flag.

        Raises:
            Exception: If model loading fails
        """

    @abstractmethod
    def generate(self, **kwargs) -> Any:
        """
        Generate output from the model.

        Method signature varies by model type. Subclasses must implement.

        Returns:
            Model-specific output (string, dict, etc.)
        """

    def unload(self) -> None:
        """
        Unload the model and free memory on a best-effort basis.

        Drops the references to the model and tokenizer, resets the loaded
        flag, runs a garbage collection pass, and, when torch is importable
        and CUDA is available, empties the CUDA cache. Cleanup failures are
        logged and never raised, so calling this is always safe.
        """
        # Local import: gc is only needed here.
        import gc

        self.model = None
        self.tokenizer = None
        self.is_loaded = False

        # Collect first and unconditionally, so the dropped references are
        # reclaimed even when torch is not installed.
        gc.collect()

        try:
            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except ImportError:
            logger.debug(
                "torch is not available; skipping CUDA cache cleanup for %s",
                self.model_name,
            )
        except Exception:
            # Stay best-effort: a failed cache flush must not break unloading,
            # but it should leave a trace instead of vanishing silently.
            logger.warning(
                "CUDA cleanup failed while unloading %s",
                self.model_name,
                exc_info=True,
            )

        logger.info("Model %s unloaded", self.model_name)

    def _validate_loaded(self) -> None:
        """
        Validate that model is loaded before inference.

        Raises:
            RuntimeError: If ``is_loaded`` is false or ``model`` is ``None``.
        """
        if not self.is_loaded or self.model is None:
            raise RuntimeError(f"Model {self.model_name} is not loaded. Call load() first.")

    def __repr__(self) -> str:
        """Return ``ClassName(name=<model_name>, status=<loaded|not loaded>)``."""
        status = "loaded" if self.is_loaded else "not loaded"
        return f"{self.__class__.__name__}(name={self.model_name}, status={status})"
|