File size: 2,380 Bytes
a067ada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fab3ba1
a067ada
 
 
fab3ba1
 
 
 
 
 
 
 
 
a067ada
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Base model class defining the interface for all specialized models.

All model implementations inherit from BaseModel and implement
the abstract methods for loading and generating outputs.
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import logging

logger = logging.getLogger(__name__)


class BaseModel(ABC):
    """Abstract base class for all model implementations.

    Subclasses implement `load()` and `generate()`; this base provides the
    shared lifecycle state (`is_loaded`, `model`, `tokenizer`) plus the
    `unload()` and `_validate_loaded()` helpers.
    """

    def __init__(self, model_name: str, model_path: Optional[str] = None) -> None:
        """
        Initialize base model.

        Args:
            model_name: Name/identifier of the model
            model_path: Path to model weights or config
        """
        self.model_name = model_name
        self.model_path = model_path
        self.is_loaded = False   # flipped True by subclass load(), False by unload()
        self.model = None        # populated by subclass load()
        self.tokenizer = None    # populated by subclass load() where applicable

    @abstractmethod
    def load(self) -> None:
        """
        Load the model and initialize it for inference.

        Must be implemented by subclasses. Should set self.model
        and update self.is_loaded flag.

        Raises:
            Exception: If model loading fails
        """
        # No `pass` needed: the docstring is a valid (abstract) body.

    @abstractmethod
    def generate(self, **kwargs) -> Any:
        """
        Generate output from the model.

        Method signature varies by model type. Subclasses must implement.

        Returns:
            Model-specific output (string, dict, etc.)
        """

    def unload(self) -> None:
        """Unload model and free GPU VRAM."""
        self.model = None
        self.tokenizer = None
        self.is_loaded = False
        # Best-effort cleanup: torch may not be installed, and CUDA teardown
        # can fail (e.g. during interpreter shutdown). Any error here is
        # deliberately swallowed so it never masks the successful unload above.
        try:
            import gc
            import torch
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except Exception:
            pass
        # Lazy %-style args: message is only formatted if INFO is enabled.
        logger.info("Model %s unloaded", self.model_name)

    def _validate_loaded(self) -> None:
        """Validate that model is loaded before inference.

        Raises:
            RuntimeError: If load() has not been called (or did not set model).
        """
        if not self.is_loaded or self.model is None:
            raise RuntimeError(f"Model {self.model_name} is not loaded. Call load() first.")

    def __repr__(self) -> str:
        """Return e.g. 'SomeModel(name=llama, status=loaded)' for debugging."""
        status = "loaded" if self.is_loaded else "not loaded"
        return f"{self.__class__.__name__}(name={self.model_name}, status={status})"