# Phillnet-2 / shared_model/shared_model.py
# Author: ayjays132 — commit 101858b ("Upload 478 files", verified)
import torch
import torch.nn as nn
from typing import Optional, Any, Dict, List
class SharedModelConfig:
    """Settings bundle describing which backbone to share and how.

    Every constructor argument is stored verbatim as an attribute of the same
    name; no validation or normalization is performed here.
    """

    def __init__(self, model_name=None, model_type=None, device="cuda", use_fp16=True):
        # Identification of the backbone; both are optional.
        self.model_name, self.model_type = model_name, model_type
        # Runtime options, kept exactly as given.
        self.device, self.use_fp16 = device, use_fp16
# Process-wide registry slots. ``SharedModel.register`` is the only code in this
# module that assigns them (via ``global``); everything else only reads them.
_GLOBAL_SHARED_MODEL: Optional[nn.Module] = None  # registered backbone module
_GLOBAL_SHARED_TOKENIZER: Any = None  # tokenizer registered alongside it, if any
_GLOBAL_SHARED_ADAPTER: Optional["SharedBackboneAdapter"] = None  # adapter wrapping both
class SharedModel:
    """
    🗜️ Zero-Overhead Shared Model Backbone
    Provides a unified interface for multiple neural subsystems
    to share the same VRAM-resident weights.

    Instances hold no model state of their own: every accessor reads the
    module-level ``_GLOBAL_SHARED_*`` slots. All instances therefore agree
    with each other and with whatever ``register`` stored most recently,
    including instances created before registration.
    """

    def __init__(self, config: Optional[SharedModelConfig] = None):
        """Create a view onto the process-wide shared backbone.

        Args:
            config: Optional settings object. It is stored as ``self.config``;
                no method of this class reads it.
        """
        self.config = config

    @property
    def _model(self) -> Optional[nn.Module]:
        # Read live instead of snapshotting in __init__. Otherwise an instance
        # created before ``register`` would keep ``None`` forever and disagree
        # with ``get_model()``.
        return _GLOBAL_SHARED_MODEL

    @property
    def _tokenizer(self) -> Any:
        # Live view of the registered tokenizer, for the same reason as ``_model``.
        return _GLOBAL_SHARED_TOKENIZER

    @classmethod
    def register(cls, model: nn.Module, tokenizer: Any = None) -> "SharedBackboneAdapter":
        """Register the primary model as the shared backbone.

        Replaces any previously registered model, tokenizer and adapter.

        Args:
            model: The loaded backbone that subsystems should share.
            tokenizer: The tokenizer paired with ``model``, if any.

        Returns:
            The newly created ``SharedBackboneAdapter``. It is also stored
            globally and returned by ``get_adapter``.
        """
        global _GLOBAL_SHARED_MODEL, _GLOBAL_SHARED_TOKENIZER, _GLOBAL_SHARED_ADAPTER
        # Build the adapter before touching any global, so a failure here
        # cannot leave a half-updated registration behind.
        adapter = SharedBackboneAdapter(model, tokenizer)
        _GLOBAL_SHARED_MODEL = model
        _GLOBAL_SHARED_TOKENIZER = tokenizer
        _GLOBAL_SHARED_ADAPTER = adapter
        return adapter

    def get_model(self) -> Optional[nn.Module]:
        """Return the registered backbone, or ``None`` if none is registered."""
        return _GLOBAL_SHARED_MODEL

    def get_tokenizer(self) -> Any:
        """Return the registered tokenizer (may be ``None``)."""
        return _GLOBAL_SHARED_TOKENIZER

    def get_adapter(self) -> Optional["SharedBackboneAdapter"]:
        """Return the adapter created by the last ``register`` call, if any."""
        return _GLOBAL_SHARED_ADAPTER

    @property
    def model(self) -> Optional[nn.Module]:
        """Attribute-style alias for ``get_model()``."""
        return _GLOBAL_SHARED_MODEL
class SharedBackboneAdapter:
    """
    Read-only bridge to the already-loaded Qwen backbone.
    Subsystems use this instead of creating private layers or fallback models.

    The adapter only holds references to the model and tokenizer; it never
    copies or reassigns weights. Structural lookups (text model, decoder
    layers, embeddings) are attribute-based via ``getattr``, so they tolerate
    wrappers that nest the language model under ``.model`` and/or
    ``.language_model``.
    """

    # Width reported by ``hidden_size`` when neither the config nor the
    # embedding layer exposes one.
    _DEFAULT_HIDDEN_SIZE = 1024

    def __init__(self, model: nn.Module, tokenizer: Any = None):
        self.model = model
        self.tokenizer = tokenizer

    def _first_parameter(self) -> Optional[torch.Tensor]:
        """Return the model's first parameter, or ``None`` if it has none."""
        # A bare ``next(...)`` would leak StopIteration for parameterless modules.
        return next(self.model.parameters(), None)

    @property
    def device(self) -> torch.device:
        """Device of the first model parameter; CPU if the model has no parameters."""
        param = self._first_parameter()
        return param.device if param is not None else torch.device("cpu")

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first model parameter; torch's default dtype if there are none."""
        param = self._first_parameter()
        return param.dtype if param is not None else torch.get_default_dtype()

    @property
    def config(self):
        """The wrapped model's ``config`` attribute, or ``None`` if it has none."""
        return getattr(self.model, "config", None)

    @property
    def hidden_size(self) -> int:
        """Hidden width of the backbone.

        Resolution order:
        1. ``config.hidden_size``
        2. ``config.text_config.hidden_size``
        3. the input embedding layer's ``embedding_dim``
        4. ``_DEFAULT_HIDDEN_SIZE``
        """
        cfg = self.config
        size = getattr(cfg, "hidden_size", None)
        if size is None:
            # Some multimodal configs nest the language-model settings under
            # ``text_config``, matching the ``language_model`` unwrapping
            # done in ``get_text_model``.
            size = getattr(getattr(cfg, "text_config", None), "hidden_size", None)
        if size is None:
            try:
                embed = self.get_embedding_layer()
            except NotImplementedError:
                # ``get_input_embeddings`` may be declared but unimplemented.
                embed = None
            size = getattr(embed, "embedding_dim", None)
        if size is None:
            size = self._DEFAULT_HIDDEN_SIZE
        return int(size)

    def get_model(self) -> nn.Module:
        """Return the wrapped model unchanged."""
        return self.model

    def get_tokenizer(self) -> Any:
        """Return the wrapped tokenizer (may be ``None``)."""
        return self.tokenizer

    def get_text_model(self) -> nn.Module:
        """Unwrap ``.model`` and then ``.language_model`` where present."""
        model = getattr(self.model, "model", self.model)
        return getattr(model, "language_model", model)

    def get_layers(self) -> List[nn.Module]:
        """Return the text model's ``layers`` as a list (empty if absent)."""
        text_model = self.get_text_model()
        return list(getattr(text_model, "layers", []))

    def get_layer(self, index: int = 0) -> Optional[nn.Module]:
        """Return layer ``index``, clamped into ``[0, num_layers - 1]``.

        Out-of-range indices are clamped, not wrapped: any negative index
        yields the first layer. Returns ``None`` when there are no layers.
        """
        layers = self.get_layers()
        if not layers:
            return None
        index = max(0, min(int(index), len(layers) - 1))
        return layers[index]

    def get_embedding_layer(self) -> Optional[nn.Module]:
        """Return the input embedding module, or ``None`` if it cannot be found.

        Prefers the model's ``get_input_embeddings()`` when that method exists;
        otherwise looks for ``embed_tokens`` on the unwrapped text model.
        """
        if hasattr(self.model, "get_input_embeddings"):
            return self.model.get_input_embeddings()
        text_model = self.get_text_model()
        return getattr(text_model, "embed_tokens", None)

    @staticmethod
    def _masked_mean(states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Mean over dim 1, excluding positions where ``attention_mask`` is 0."""
        if attention_mask is None:
            return states.mean(dim=1)
        # Move the mask next to ``states``: with a model split across devices,
        # the last hidden state can live on a different device than the inputs.
        mask = attention_mask.to(device=states.device, dtype=states.dtype).unsqueeze(-1)
        return (states * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1)

    def encode_text(self, text: str, max_length: int = 256, use_hidden: bool = False) -> torch.Tensor:
        """Encode ``text`` into a pooled vector, without tracking gradients.

        Args:
            text: Input passed to the shared tokenizer.
            max_length: Truncation length used when tokenizing.
            use_hidden: If true, run a forward pass and mean-pool the last
                hidden state. If the model output exposes neither
                ``hidden_states`` nor ``last_hidden_state``, or if
                ``use_hidden`` is false, mean-pool the input token embeddings
                instead.

        Returns:
            The result of mean-pooling over dim 1, cast to ``self.dtype``.
            Positions with attention mask 0 are excluded on both paths. This
            assumes the pooled tensors are ``(batch, seq, hidden)``, giving
            ``(batch, hidden)`` (confirm for custom models).

        Raises:
            RuntimeError: If no tokenizer is registered, or if the embedding
                path is needed but the model has no input embedding layer.
        """
        if self.tokenizer is None:
            raise RuntimeError("Shared tokenizer is not registered.")
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            add_special_tokens=True,
        )
        device = self.device  # hoisted: each lookup walks model.parameters()
        inputs = {k: v.to(device) for k, v in inputs.items()}
        attention_mask = inputs.get("attention_mask")
        with torch.no_grad():
            if use_hidden:
                outputs = self.model(**inputs, output_hidden_states=True, use_cache=False)
                hidden = getattr(outputs, "hidden_states", None)
                if hidden is not None and len(hidden) > 0:
                    states = hidden[-1]
                else:
                    states = getattr(outputs, "last_hidden_state", None)
                if states is not None:
                    return self._masked_mean(states, attention_mask).to(dtype=self.dtype)
            embed = self.get_embedding_layer()
            if embed is None:
                raise RuntimeError("Shared model has no input embedding layer.")
            token_embeddings = embed(inputs["input_ids"])
            return self._masked_mean(token_embeddings, attention_mask).to(dtype=self.dtype)

    def summary(self) -> Dict[str, Any]:
        """Return a small dict describing the wrapped model for logging/debugging."""
        return {
            "model_type": type(self.model).__name__,
            "tokenizer_type": type(self.tokenizer).__name__ if self.tokenizer is not None else None,
            "hidden_size": self.hidden_size,
            "num_layers": len(self.get_layers()),
            "device": str(self.device),
            "dtype": str(self.dtype),
        }
def get_shared_model() -> Optional[nn.Module]:
    """Return the registered shared backbone (``None`` until something is registered)."""
    backbone = _GLOBAL_SHARED_MODEL
    return backbone
def get_shared_tokenizer() -> Any:
    """Return the tokenizer registered with the shared backbone (may be ``None``)."""
    tok = _GLOBAL_SHARED_TOKENIZER
    return tok
def get_shared_adapter() -> Optional[SharedBackboneAdapter]:
    """Return the adapter built by the most recent registration, or ``None``."""
    adapter = _GLOBAL_SHARED_ADAPTER
    return adapter
def register_shared_model(model: nn.Module, tokenizer: Any = None) -> SharedBackboneAdapter:
    """Register ``model``/``tokenizer`` as the shared backbone and return its adapter.

    Thin wrapper around ``SharedModel.register``. The value returned is
    whatever that call left in the module-level adapter slot.
    """
    SharedModel.register(model, tokenizer=tokenizer)
    adapter = _GLOBAL_SHARED_ADAPTER
    return adapter