"""Process-wide registry for sharing one loaded model between subsystems."""

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

# Last-resort hidden width reported when neither the model config nor the
# embedding layer reveals the real value.
_DEFAULT_HIDDEN_SIZE = 1024


class SharedModelConfig:
    """Plain settings holder describing the shared model.

    The values are only stored; nothing in this module interprets them.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        model_type: Optional[str] = None,
        device: str = "cuda",
        use_fp16: bool = True,
    ):
        self.model_name = model_name
        self.model_type = model_type
        self.device = device
        self.use_fp16 = use_fp16


# Module-level registry slots, populated by ``SharedModel.register``.
_GLOBAL_SHARED_MODEL = None
_GLOBAL_SHARED_TOKENIZER = None
_GLOBAL_SHARED_ADAPTER = None


class SharedModel:
    """
    Zero-Overhead Shared Model Backbone

    Provides a unified interface for multiple neural subsystems to share
    the same VRAM-resident weights.
    """

    def __init__(self, config: Optional[SharedModelConfig] = None):
        self.config = config
        # Snapshots of the registry taken at construction time. They are not
        # refreshed by a later ``register`` call; the accessor methods below
        # always read the live module-level registry instead.
        self._model = _GLOBAL_SHARED_MODEL
        self._tokenizer = _GLOBAL_SHARED_TOKENIZER

    @classmethod
    def register(cls, model: nn.Module, tokenizer: Any = None):
        """Register the primary model as the shared backbone.

        Replaces any previously registered model/tokenizer and builds a new
        ``SharedBackboneAdapter`` around the pair.
        """
        global _GLOBAL_SHARED_MODEL, _GLOBAL_SHARED_TOKENIZER, _GLOBAL_SHARED_ADAPTER
        _GLOBAL_SHARED_MODEL = model
        _GLOBAL_SHARED_TOKENIZER = tokenizer
        _GLOBAL_SHARED_ADAPTER = SharedBackboneAdapter(model, tokenizer)

    def get_model(self) -> Optional[nn.Module]:
        """Return the currently registered model, or ``None``."""
        return _GLOBAL_SHARED_MODEL

    def get_tokenizer(self) -> Any:
        """Return the currently registered tokenizer, or ``None``."""
        return _GLOBAL_SHARED_TOKENIZER

    def get_adapter(self) -> Optional["SharedBackboneAdapter"]:
        """Return the adapter built by the last ``register`` call, or ``None``."""
        return _GLOBAL_SHARED_ADAPTER

    @property
    def model(self) -> Optional[nn.Module]:
        """Currently registered model (same value as ``get_model()``)."""
        return _GLOBAL_SHARED_MODEL


class SharedBackboneAdapter:
    """
    Read-only bridge to the already-loaded Qwen backbone.

    Subsystems use this instead of creating private layers or fallback models.
    """

    def __init__(self, model: nn.Module, tokenizer: Any = None):
        self.model = model
        self.tokenizer = tokenizer

    def _reference_tensor(self) -> Optional[torch.Tensor]:
        """Return the first parameter (or, failing that, buffer) of the model."""
        tensor = next(self.model.parameters(), None)
        if tensor is None:
            tensor = next(self.model.buffers(), None)
        return tensor

    @property
    def device(self) -> torch.device:
        """Device of the model's first parameter/buffer; CPU if it has none."""
        # A bare ``next(...)`` would leak StopIteration for parameterless modules.
        tensor = self._reference_tensor()
        return tensor.device if tensor is not None else torch.device("cpu")

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the model's first parameter/buffer; torch default if none."""
        tensor = self._reference_tensor()
        return tensor.dtype if tensor is not None else torch.get_default_dtype()

    @property
    def config(self):
        """The model's ``config`` attribute, or ``None`` when it has none."""
        return getattr(self.model, "config", None)

    @property
    def hidden_size(self) -> int:
        """Hidden width of the backbone.

        Resolution order: ``config.hidden_size``, ``config.text_config.hidden_size``
        (composite configs may nest the text settings there), the last dimension
        of the input-embedding weight, and finally the historical default 1024.
        """
        config = self.config
        for source in (config, getattr(config, "text_config", None)):
            size = getattr(source, "hidden_size", None)
            if size is not None:
                return int(size)
        # No usable config value: measure the embedding table instead of guessing.
        weight = getattr(self.get_embedding_layer(), "weight", None)
        shape = getattr(weight, "shape", None)
        if shape:
            return int(shape[-1])
        return _DEFAULT_HIDDEN_SIZE

    def get_model(self) -> nn.Module:
        """Return the wrapped model."""
        return self.model

    def get_tokenizer(self) -> Any:
        """Return the wrapped tokenizer (may be ``None``)."""
        return self.tokenizer

    def get_text_model(self) -> nn.Module:
        """Return the innermost text module.

        Unwraps ``.model`` and then ``.language_model`` when those attributes
        exist; otherwise the object itself is returned at each step.
        """
        model = getattr(self.model, "model", self.model)
        return getattr(model, "language_model", model)

    def get_layers(self) -> List[nn.Module]:
        """Return the text model's ``layers`` as a list (empty if absent)."""
        text_model = self.get_text_model()
        return list(getattr(text_model, "layers", []))

    def get_layer(self, index: int = 0) -> Optional[nn.Module]:
        """Return the layer at ``index``, clamped into ``[0, len - 1]``.

        Out-of-range (including negative) indices are clamped rather than
        wrapped or rejected. Returns ``None`` when the model exposes no layers.
        """
        layers = self.get_layers()
        if not layers:
            return None
        index = max(0, min(int(index), len(layers) - 1))
        return layers[index]

    def get_embedding_layer(self) -> Optional[nn.Module]:
        """Return the input embedding module, or ``None`` if it cannot be found.

        Prefers the model's ``get_input_embeddings()`` and falls back to the
        text model's ``embed_tokens`` attribute.
        """
        if hasattr(self.model, "get_input_embeddings"):
            return self.model.get_input_embeddings()
        text_model = self.get_text_model()
        return getattr(text_model, "embed_tokens", None)

    def encode_text(self, text: str, max_length: int = 256, use_hidden: bool = False) -> torch.Tensor:
        """Encode ``text`` into a vector pooled over dimension 1 of the token states.

        Args:
            text: Input passed to the shared tokenizer.
            max_length: Truncation length handed to the tokenizer.
            use_hidden: When true, run the model and average its last hidden
                state over dimension 1 (no attention masking), cast to the
                model dtype. If the model output carries no hidden states, or
                when false, the attention-masked mean of the input embeddings
                is returned instead.

        Raises:
            RuntimeError: If no tokenizer is registered, or the embedding path
                is needed but the model exposes no input embedding layer.
        """
        if self.tokenizer is None:
            raise RuntimeError("Shared tokenizer is not registered.")
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            add_special_tokens=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            if use_hidden:
                outputs = self.model(**inputs, output_hidden_states=True, use_cache=False)
                hidden = getattr(outputs, "hidden_states", None)
                if hidden:
                    states = hidden[-1]
                else:
                    states = getattr(outputs, "last_hidden_state", None)
                if states is not None:
                    return states.mean(dim=1).to(dtype=self.dtype)
                # No hidden states available: fall through to the embedding path.
            embed = self.get_embedding_layer()
            if embed is None:
                raise RuntimeError("Shared model has no input embedding layer.")
            token_embeddings = embed(inputs["input_ids"])
            attention_mask = inputs.get("attention_mask")
            if attention_mask is None:
                return token_embeddings.mean(dim=1)
            # Masked mean; the clamp guards against division by zero when the
            # mask is entirely zero.
            mask = attention_mask.unsqueeze(-1).to(dtype=token_embeddings.dtype)
            return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1)

    def summary(self) -> Dict[str, Any]:
        """Return a small diagnostic dict describing the wrapped backbone."""
        return {
            "model_type": type(self.model).__name__,
            "tokenizer_type": type(self.tokenizer).__name__ if self.tokenizer is not None else None,
            "hidden_size": self.hidden_size,
            "num_layers": len(self.get_layers()),
            "device": str(self.device),
            "dtype": str(self.dtype),
        }


def get_shared_model() -> Optional[nn.Module]:
    """Return the registered shared model, or ``None``."""
    return _GLOBAL_SHARED_MODEL


def get_shared_tokenizer() -> Any:
    """Return the registered shared tokenizer, or ``None``."""
    return _GLOBAL_SHARED_TOKENIZER


def get_shared_adapter() -> Optional[SharedBackboneAdapter]:
    """Return the registered shared adapter, or ``None``."""
    return _GLOBAL_SHARED_ADAPTER


def register_shared_model(model: nn.Module, tokenizer: Any = None) -> SharedBackboneAdapter:
    """Register ``model``/``tokenizer`` globally and return the new adapter."""
    SharedModel.register(model, tokenizer)
    return _GLOBAL_SHARED_ADAPTER