File size: 5,513 Bytes
101858b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155

import torch
import torch.nn as nn
from typing import Optional, Any, Dict, List

class SharedModelConfig:
    """Plain settings holder describing the shared backbone.

    Every constructor argument is stored verbatim as an attribute of the same
    name; no validation or normalisation is performed here.
    """

    def __init__(
        self,
        model_name: Any = None,
        model_type: Any = None,
        device: str = "cuda",
        use_fp16: bool = True,
    ):
        # Identification fields first, then runtime placement/precision knobs.
        self.model_name, self.model_type = model_name, model_type
        self.device = device
        self.use_fp16 = use_fp16

# Module-wide registry of the shared backbone. All three start empty and are
# overwritten together by ``SharedModel.register``.
_GLOBAL_SHARED_MODEL: Optional[nn.Module] = None
_GLOBAL_SHARED_TOKENIZER: Any = None
_GLOBAL_SHARED_ADAPTER: Optional["SharedBackboneAdapter"] = None

class SharedModel:
    """
    🗜️ Zero-Overhead Shared Model Backbone
    Provides a unified interface for multiple neural subsystems
    to share the same VRAM-resident weights.

    ``get_model``, ``get_tokenizer``, ``get_adapter`` and ``model`` read the
    module-level registry on every call, so any instance sees whatever was
    most recently passed to :meth:`register`, regardless of when the
    instance itself was created.
    """

    def __init__(self, config: Optional["SharedModelConfig"] = None):
        # ``config`` is stored as-is; nothing in this class reads it.
        self.config = config
        # NOTE(review): these are snapshots of the registry taken at
        # construction time and are NOT refreshed by a later ``register``
        # call. Prefer ``get_model()`` / ``get_tokenizer()``, which read the
        # live registry.
        self._model = _GLOBAL_SHARED_MODEL
        self._tokenizer = _GLOBAL_SHARED_TOKENIZER

    @classmethod
    def register(cls, model: nn.Module, tokenizer: Any = None) -> "SharedBackboneAdapter":
        """Register ``model`` (and optionally ``tokenizer``) as the shared backbone.

        Any previously registered model, tokenizer and adapter are replaced.

        Args:
            model: The already-loaded backbone module to share.
            tokenizer: Tokenizer paired with ``model``; may be ``None``.

        Returns:
            The newly created ``SharedBackboneAdapter`` wrapping ``model``.
        """
        global _GLOBAL_SHARED_MODEL, _GLOBAL_SHARED_TOKENIZER, _GLOBAL_SHARED_ADAPTER
        # Build the adapter before touching any global so a failure here
        # cannot leave the registry half-updated (new model, stale adapter).
        adapter = SharedBackboneAdapter(model, tokenizer)
        _GLOBAL_SHARED_MODEL = model
        _GLOBAL_SHARED_TOKENIZER = tokenizer
        _GLOBAL_SHARED_ADAPTER = adapter
        return adapter

    def get_model(self) -> Optional[nn.Module]:
        """Return the currently registered backbone, or ``None``."""
        return _GLOBAL_SHARED_MODEL

    def get_tokenizer(self) -> Any:
        """Return the currently registered tokenizer, or ``None``."""
        return _GLOBAL_SHARED_TOKENIZER

    def get_adapter(self) -> Optional["SharedBackboneAdapter"]:
        """Return the adapter built by the last ``register`` call, or ``None``."""
        return _GLOBAL_SHARED_ADAPTER

    @property
    def model(self) -> Optional[nn.Module]:
        """Attribute-style alias for :meth:`get_model`."""
        return _GLOBAL_SHARED_MODEL


class SharedBackboneAdapter:
    """
    Read-only bridge to the already-loaded Qwen backbone.
    Subsystems use this instead of creating private layers or fallback models.

    The adapter only holds references to the registered model and tokenizer;
    it does not copy, wrap, or reassign the model's weights.
    """

    def __init__(self, model: nn.Module, tokenizer: Any = None):
        self.model = model
        self.tokenizer = tokenizer

    @property
    def device(self) -> torch.device:
        """Device of the model's first parameter (CPU if the model has none)."""
        first = next(self.model.parameters(), None)
        # A parameter-less model would otherwise leak a bare StopIteration.
        return first.device if first is not None else torch.device("cpu")

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the model's first parameter (torch's default dtype if it has none)."""
        first = next(self.model.parameters(), None)
        return first.dtype if first is not None else torch.get_default_dtype()

    @property
    def config(self):
        """The wrapped model's ``config`` attribute, or ``None`` if absent."""
        return getattr(self.model, "config", None)

    @property
    def hidden_size(self) -> int:
        """Hidden width of the text backbone; 1024 when it cannot be determined."""
        config = self.config
        size = getattr(config, "hidden_size", None)
        if size is None:
            # Some composite/multimodal configs expose the text width only on
            # a nested ``text_config`` (mirrors ``get_text_model`` unwrapping
            # ``language_model``).
            size = getattr(getattr(config, "text_config", None), "hidden_size", None)
        return int(size) if size is not None else 1024

    def get_model(self) -> nn.Module:
        """Return the wrapped model."""
        return self.model

    def get_tokenizer(self) -> Any:
        """Return the wrapped tokenizer (may be ``None``)."""
        return self.tokenizer

    def get_text_model(self) -> nn.Module:
        """Return the text decoder, unwrapping ``.model`` and then ``.language_model`` when present."""
        model = getattr(self.model, "model", self.model)
        return getattr(model, "language_model", model)

    def get_layers(self) -> List[nn.Module]:
        """Return the text model's ``layers`` as a new list (empty if it has none)."""
        text_model = self.get_text_model()
        return list(getattr(text_model, "layers", []))

    def get_layer(self, index: int = 0) -> Optional[nn.Module]:
        """Return layer ``index`` clamped into range, or ``None`` if there are no layers.

        Negative indexes clamp to the first layer; they do not count from the
        end the way Python sequence indexing does.
        """
        layers = self.get_layers()
        if not layers:
            return None
        index = max(0, min(int(index), len(layers) - 1))
        return layers[index]

    def get_embedding_layer(self) -> Optional[nn.Module]:
        """Return the input embedding module, or ``None`` if none can be found.

        Prefers the model's ``get_input_embeddings()``; otherwise falls back to
        the text model's ``embed_tokens`` attribute.
        """
        if hasattr(self.model, "get_input_embeddings"):
            return self.model.get_input_embeddings()
        text_model = self.get_text_model()
        return getattr(text_model, "embed_tokens", None)

    @staticmethod
    def _masked_mean(states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Mean over dim 1, counting only positions where ``attention_mask`` is non-zero.

        Falls back to a plain mean when there is no mask or its leading two
        dimensions do not match ``states`` (e.g. a model that adds positions).
        """
        if attention_mask is None or tuple(attention_mask.shape[:2]) != tuple(states.shape[:2]):
            return states.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(dtype=states.dtype)
        # clamp_min avoids division by zero for fully-masked rows.
        return (states * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1)

    def encode_text(self, text: str, max_length: int = 256, use_hidden: bool = False) -> torch.Tensor:
        """Encode ``text`` into a pooled vector using the shared backbone.

        Args:
            text: Input text passed to the shared tokenizer.
            max_length: Truncation length used during tokenization.
            use_hidden: When true, run a full forward pass (no grad, no KV
                cache) and pool the last hidden state. If the output exposes
                neither ``hidden_states`` nor ``last_hidden_state``, fall back
                to pooling the input token embeddings.

        Returns:
            The selected states with dim 1 reduced by an attention-mask-weighted
            mean (presumably ``(batch, hidden)`` for HF-style ``(batch, seq,
            hidden)`` outputs). Hidden-state results are cast to ``self.dtype``.

        Raises:
            RuntimeError: If no tokenizer is registered, or if the embedding
                path is needed but the model has no input embedding layer.
        """
        if self.tokenizer is None:
            raise RuntimeError("Shared tokenizer is not registered.")
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            add_special_tokens=True,
        )
        # Hoisted: each ``self.device`` access walks model.parameters().
        device = self.device
        inputs = {k: v.to(device) for k, v in encoded.items()}
        attention_mask = inputs.get("attention_mask")
        with torch.no_grad():
            if use_hidden:
                outputs = self.model(**inputs, output_hidden_states=True, use_cache=False)
                hidden = getattr(outputs, "hidden_states", None)
                if hidden:
                    states = hidden[-1]
                else:
                    states = getattr(outputs, "last_hidden_state", None)
                if states is not None:
                    # Same masked pooling as the embedding path below.
                    return self._masked_mean(states, attention_mask).to(dtype=self.dtype)
            embed = self.get_embedding_layer()
            if embed is None:
                raise RuntimeError("Shared model has no input embedding layer.")
            token_embeddings = embed(inputs["input_ids"])
            return self._masked_mean(token_embeddings, attention_mask)

    def summary(self) -> Dict[str, Any]:
        """Return a JSON-friendly description of the wrapped model and tokenizer."""
        return {
            "model_type": type(self.model).__name__,
            "tokenizer_type": type(self.tokenizer).__name__ if self.tokenizer is not None else None,
            "hidden_size": self.hidden_size,
            "num_layers": len(self.get_layers()),
            "device": str(self.device),
            "dtype": str(self.dtype),
        }

def get_shared_model() -> Optional[nn.Module]:
    """Return the registered shared backbone module, or ``None`` before registration."""
    current = _GLOBAL_SHARED_MODEL
    return current

def get_shared_tokenizer() -> Any:
    """Return the registered shared tokenizer, or ``None`` before registration."""
    current = _GLOBAL_SHARED_TOKENIZER
    return current

def get_shared_adapter() -> Optional[SharedBackboneAdapter]:
    """Return the adapter created by the most recent registration, or ``None``."""
    current = _GLOBAL_SHARED_ADAPTER
    return current

def register_shared_model(model: nn.Module, tokenizer: Any = None) -> SharedBackboneAdapter:
    """Register ``model``/``tokenizer`` as the shared backbone and return its adapter.

    Delegates to ``SharedModel.register`` and then hands back the adapter that
    call stored in the module registry.
    """
    SharedModel.register(model, tokenizer=tokenizer)
    adapter = _GLOBAL_SHARED_ADAPTER
    return adapter