# smb-structure-llama3-8b-multi-objective / modeling_smb_unstructured.py
# anon-9421's picture
# Upload folder using huggingface_hub
# ef48861 verified
"""SMB-Structure text-only language model wrapper.
Minimal inference wrapper around Qwen3 / Qwen2 / Llama / Phi backbones.
Patient-timeline text is fed in directly; hidden states are exposed via
the standard HF forward signature (output_hidden_states=True).

Usage:
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        "<anonymous-hf-org>/smb-structure-qwen3-1.7b",
        trust_remote_code=True,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "<anonymous-hf-org>/smb-structure-qwen3-1.7b"
    )
"""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import (
PreTrainedModel,
PretrainedConfig,
AutoConfig,
AutoModelForCausalLM,
GenerationMixin,
LlamaForCausalLM,
Qwen2ForCausalLM,
PhiForCausalLM,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
# Qwen3 is only shipped by some transformers releases. When the import
# fails the name is bound to None and the availability flag is False.
try:
    from transformers import Qwen3ForCausalLM
except ImportError:
    Qwen3ForCausalLM = None

# True exactly when the optional import above succeeded.
HAS_QWEN3 = Qwen3ForCausalLM is not None
# =============================================================================
# LLM BACKEND MAPPING
# =============================================================================
def get_llm_class(model_type: str):
    """Resolve the causal-LM class used to build the backbone.

    The lower-cased ``model_type`` is matched by substring, most specific
    family first, so ``"qwen3"`` is tested before the generic ``"qwen"``
    rule and the Llama 3 spellings before plain ``"llama"``.

    Args:
        model_type: Model type / name string (case-insensitive).

    Returns:
        The matching ``*ForCausalLM`` class, or ``AutoModelForCausalLM`` for
        Llama 3, for Qwen3 when no dedicated class could be imported, and
        for any unrecognised name.
    """
    name = model_type.lower()

    if "qwen3" in name:
        # Use the dedicated class only when the optional import succeeded.
        return Qwen3ForCausalLM if HAS_QWEN3 else AutoModelForCausalLM
    if "qwen" in name:
        # Also covers "qwen2": every such name contains "qwen".
        return Qwen2ForCausalLM
    if "llama-3" in name or "llama3" in name:
        # Also covers "meta-llama-3", which contains "llama-3".
        return AutoModelForCausalLM  # Llama 3 uses AutoModelForCausalLM
    if "llama" in name or "vicuna" in name:
        return LlamaForCausalLM
    if "phi" in name:
        return PhiForCausalLM
    return AutoModelForCausalLM
# =============================================================================
# CHAT TEMPLATES
# =============================================================================
@dataclass
class Qwen3Template:
    """ChatML prompt template for Qwen3 models."""

    # System message used when format_chat() receives no (or a falsy) one.
    system_prompt: str = "You are a helpful assistant."

    def format_chat(self, prompt: str, system: str = None) -> str:
        """Render a single-turn ChatML prompt ending in an open assistant turn.

        Args:
            prompt: User message.
            system: System message; any falsy value selects
                ``self.system_prompt``.

        Returns:
            The system turn, the user turn and the assistant header.
        """
        system_text = system if system else self.system_prompt
        rendered = (
            f"<|im_start|>system\n{system_text}<|im_end|>\n"
            f"<|im_start|>user\n{prompt}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )
        return rendered

    @property
    def stop_tokens(self) -> List[str]:
        """Stop-token strings associated with this template."""
        return ["<|im_end|>", "<|endoftext|>"]
@dataclass
class Llama3Template:
    """Plain-text ``USER:`` / ``ASSISTANT:`` prompt template for Llama 3 models."""

    # Preamble used when format_chat() receives no (or a falsy) system text.
    system_prompt: str = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."

    def format_chat(self, prompt: str, system: str = None) -> str:
        """Render a single-turn prompt that ends right after ``ASSISTANT:``.

        Args:
            prompt: User message.
            system: Preamble override; any falsy value selects
                ``self.system_prompt``.

        Returns:
            The preamble, the user message and the assistant cue on one line.
        """
        if system:
            preamble = system
        else:
            preamble = self.system_prompt
        return f"{preamble} USER: {prompt} ASSISTANT:"

    @property
    def stop_tokens(self) -> List[str]:
        """Stop-token strings associated with this template."""
        return ["<|end_of_text|>", "<|eot_id|>"]
@dataclass
class Qwen2Template:
    """Plain-text ``USER:`` / ``ASSISTANT:`` prompt template for Qwen2 base models."""

    # Preamble used when format_chat() receives no (or a falsy) system text.
    system_prompt: str = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."

    def format_chat(self, prompt: str, system: str = None) -> str:
        """Render a single-turn prompt that ends right after ``ASSISTANT:``.

        Args:
            prompt: User message.
            system: Preamble override; any falsy value selects
                ``self.system_prompt``.

        Returns:
            The preamble, the user message and the assistant cue on one line.
        """
        preamble = self.system_prompt if not system else system
        return f"{preamble} USER: {prompt} ASSISTANT:"

    @property
    def stop_tokens(self) -> List[str]:
        """Stop-token strings associated with this template."""
        return ["<|endoftext|>", "<|im_end|>"]
@dataclass
class LlamaTemplate:
    """Plain-text ``USER:`` / ``ASSISTANT:`` prompt template for Llama/Vicuna models."""

    # Preamble used when format_chat() receives no (or a falsy) system text.
    system_prompt: str = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."

    def format_chat(self, prompt: str, system: str = None) -> str:
        """Render a single-turn prompt that ends right after ``ASSISTANT:``.

        Args:
            prompt: User message.
            system: Preamble override; any falsy value selects
                ``self.system_prompt``.

        Returns:
            The preamble, the user message and the assistant cue on one line.
        """
        preamble = system
        if not preamble:
            # Covers both None and the empty string.
            preamble = self.system_prompt
        return f"{preamble} USER: {prompt} ASSISTANT:"

    @property
    def stop_tokens(self) -> List[str]:
        """Stop-token strings associated with this template."""
        return ["</s>"]
@dataclass
class PhiTemplate:
    """Plain-text ``USER:`` / ``ASSISTANT:`` prompt template for Phi models."""

    # Preamble used when format_chat() receives no (or a falsy) system text.
    system_prompt: str = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."

    def format_chat(self, prompt: str, system: str = None) -> str:
        """Render a single-turn prompt that ends right after ``ASSISTANT:``.

        Args:
            prompt: User message.
            system: Preamble override; any falsy value selects
                ``self.system_prompt``.

        Returns:
            The preamble, the user message and the assistant cue on one line.
        """
        preamble = system if system else self.system_prompt
        rendered = f"{preamble} USER: {prompt} ASSISTANT:"
        return rendered

    @property
    def stop_tokens(self) -> List[str]:
        """Stop-token strings associated with this template."""
        return ["<|endoftext|>"]
def get_template(model_type: str):
    """Pick the chat-template object for a model type string.

    The lower-cased ``model_type`` is checked against an ordered list of
    substring rules; the first rule with a matching substring wins.

    Args:
        model_type: Model type / name string (case-insensitive).

    Returns:
        A new template instance; ``Qwen3Template`` when no rule matches.
    """
    name = model_type.lower()

    # Order matters: specific families must precede the generic substrings
    # they contain ("qwen3" before "qwen", "llama-3"/"llama3" before "llama").
    rules = (
        (("qwen3",), Qwen3Template),
        (("qwen",), Qwen2Template),
        (("llama-3", "llama3"), Llama3Template),
        (("llama", "vicuna"), LlamaTemplate),
        (("phi",), PhiTemplate),
    )
    for needles, template_cls in rules:
        if any(needle in name for needle in needles):
            return template_cls()
    return Qwen3Template()  # Default
# =============================================================================
# CONFIGURATION
# =============================================================================
class SMBUnstructuredConfig(PretrainedConfig):
    """Configuration for SMB Unstructured text-only model.

    Args:
        llm_model_name_or_path: Name or path of the backbone LLM.
        tokenizer_name_or_path: Name or path of the tokenizer; falls back to
            ``llm_model_name_or_path`` when falsy.
        text_config: Backbone configuration, either as a dict (rebuilt with
            ``AutoConfig.for_model``) or as an already-built config object.
        hidden_size: Used only when ``text_config`` is missing or does not
            define ``hidden_size``.
        vocab_size: Used only when ``text_config`` is missing or does not
            define ``vocab_size``.
        pad_token: Padding token string.
        pad_token_id: Padding token id.
        tokenizer_padding_side: Stored tokenizer padding side.
        tokenizer_model_max_length: Stored tokenizer maximum length.
        use_cache: Default for ``use_cache`` in the model's forward pass.
        **kwargs: Passed on to ``PretrainedConfig.__init__``.
    """

    model_type = "smb_unstructured"

    def __init__(
        self,
        llm_model_name_or_path: str = "",
        tokenizer_name_or_path: str = None,
        text_config: dict = None,
        hidden_size: int = 2048,
        vocab_size: int = 32000,
        pad_token: str = None,
        pad_token_id: int = None,
        tokenizer_padding_side: str = "right",
        tokenizer_model_max_length: int = 2048,
        use_cache: bool = True,
        **kwargs
    ):
        self.llm_model_name_or_path = llm_model_name_or_path
        self.tokenizer_name_or_path = tokenizer_name_or_path or llm_model_name_or_path
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.pad_token = pad_token
        self.pad_token_id = pad_token_id
        self.tokenizer_padding_side = tokenizer_padding_side
        self.tokenizer_model_max_length = tokenizer_model_max_length
        self.use_cache = use_cache

        # Load text config: a dict (e.g. from a serialized config) is turned
        # back into a config object, anything else is stored as given.
        if isinstance(text_config, dict):
            text_config = AutoConfig.for_model(**text_config)
        self.text_config = text_config

        # The backbone config is authoritative for hidden_size / vocab_size;
        # the explicit arguments only act as fallbacks.
        if self.text_config is not None:
            self.hidden_size = getattr(self.text_config, "hidden_size", hidden_size)
            self.vocab_size = getattr(self.text_config, "vocab_size", vocab_size)

        # pad_token_id is a named parameter here, so it is not part of
        # **kwargs. The base initializer assigns pad_token_id from its own
        # keyword arguments, so it must be passed explicitly or the value
        # stored above is reset to None.
        super().__init__(pad_token_id=pad_token_id, **kwargs)
# =============================================================================
# MAIN MODEL
# =============================================================================
class SMBUnstructuredPreTrainedModel(PreTrainedModel):
    """Shared base class for SMB Unstructured models.

    Declares the config class and loading-related class attributes, and
    supplies the weight initializer for linear and embedding layers.
    """

    config_class = SMBUnstructuredConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialize ``nn.Linear`` / ``nn.Embedding`` weights in place.

        Weights are drawn from a normal distribution with mean 0 and standard
        deviation ``config.initializer_range`` (0.02 when the config has no
        such attribute); linear biases are zeroed. Other module types are
        left untouched.
        """
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        init_std = getattr(self.config, "initializer_range", 0.02)
        module.weight.data.normal_(mean=0.0, std=init_std)

        # Only linear layers carry a bias, and it may be disabled.
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
class SMBUnstructuredForCausalLM(SMBUnstructuredPreTrainedModel, GenerationMixin):
    """
    SMB Unstructured text-only language model.

    A minimal wrapper around the base LLM for inference. The backbone is
    stored as ``self.language_model``; the embedding/decoder accessors,
    ``forward``, ``generate`` and ``prepare_inputs_for_generation`` all
    delegate to it. ``chat`` adds prompt templating on top.
    """

    def __init__(self, config: SMBUnstructuredConfig):
        """Build the backbone and chat template described by ``config``.

        Args:
            config: Wrapper configuration; ``config.text_config`` is handed
                to the backbone constructor.
        """
        super().__init__(config)

        # Detect LLM type from text_config; "qwen3" is the fallback.
        # NOTE(review): the type string is matched by substring in
        # get_llm_class() / get_template(); confirm that the model_type of
        # the shipped checkpoints selects the intended template.
        if config.text_config is not None:
            llm_type = getattr(config.text_config, "model_type", "qwen3")
        else:
            llm_type = "qwen3"

        # Initialize language model. The Auto class is a factory and is
        # built through from_config(); concrete classes take the config.
        llm_class = get_llm_class(llm_type)
        if llm_class is AutoModelForCausalLM:
            self.language_model = llm_class.from_config(config.text_config)
        else:
            self.language_model = llm_class(config.text_config)

        # Get chat template
        self.template = get_template(llm_type)
        self._llm_type = llm_type

        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's input embedding module."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the backbone's input embedding module."""
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the backbone's output embedding (LM head) module."""
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the backbone's output embedding (LM head) module."""
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self):
        """Return the backbone's decoder."""
        return self.language_model.get_decoder()

    def set_decoder(self, decoder):
        """Replace the backbone's decoder."""
        self.language_model.set_decoder(decoder)

    def tie_weights(self):
        """Delegate weight tying to the backbone."""
        return self.language_model.tie_weights()

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        """Resize the backbone's token embeddings and sync both configs.

        Args:
            new_num_tokens: Requested vocabulary size.
            pad_to_multiple_of: Optional multiple to pad the size up to.

        Returns:
            The resized input embedding module.
        """
        model_embeds = self.language_model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of
        )
        # Record the actual size, which may differ from the request.
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Forward pass - direct passthrough to language model.

        All named arguments are handed to the backbone unchanged, except
        that ``use_cache=None`` is replaced by ``self.config.use_cache``.

        NOTE(review): extra ``**kwargs`` are accepted but not forwarded to
        the backbone; confirm whether callers rely on any of them.

        Returns:
            Whatever the backbone returns for these arguments.
        """
        # Call the module instance rather than .forward() so that hooks
        # registered on the backbone are executed.
        return self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache if use_cache is not None else self.config.use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate text - direct passthrough to language model.

        Args:
            inputs: Prompt token ids. Used as ``input_ids`` unless
                ``inputs_embeds`` is supplied in ``kwargs``, in which case
                it is ignored.
            **kwargs: Forwarded to the backbone's ``generate``.

        Returns:
            The backbone's generation output.
        """
        # Handle inputs_embeds vs input_ids
        if "inputs_embeds" not in kwargs and inputs is not None:
            kwargs["input_ids"] = inputs
        return self.language_model.generate(**kwargs)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        **kwargs
    ):
        """Prepare inputs for generation by delegating to the backbone."""
        return self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def chat(
        self,
        prompt: str,
        tokenizer,
        system_prompt: str = None,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 0.9,
        top_k: int = 50,
        do_sample: bool = None,
        **kwargs,
    ) -> str:
        """
        Chat interface for text generation.

        Args:
            prompt: User input prompt.
            tokenizer: Tokenizer instance.
            system_prompt: Optional system prompt override.
            max_new_tokens: Maximum tokens to generate.
            temperature: Sampling temperature (0 = greedy).
            top_p: Nucleus sampling parameter (default 0.9).
            top_k: Top-k sampling parameter (default 50).
            do_sample: Whether to sample (auto-detected from temperature).
            **kwargs: Additional generation arguments; keys already set by
                this method are ignored.

        Returns:
            Generated text response.
        """
        # Format prompt with template
        formatted_prompt = self.template.format_chat(prompt, system_prompt)

        # Tokenize
        inputs = tokenizer(formatted_prompt, return_tensors="pt")
        input_ids = inputs.input_ids.to(self.device)
        attention_mask = inputs.attention_mask.to(self.device)
        input_length = input_ids.shape[1]

        # Build stop token IDs: the tokenizer's EOS plus every template stop
        # token the tokenizer actually knows.
        eos_token_ids = []
        if tokenizer.eos_token_id is not None:
            eos_token_ids.append(tokenizer.eos_token_id)
        for token in self.template.stop_tokens:
            token_id = tokenizer.convert_tokens_to_ids(token)
            # Skip tokens missing from the vocabulary, whether the tokenizer
            # reports them as None or as its unknown-token id.
            if token_id is None or token_id == tokenizer.unk_token_id:
                continue
            if token_id not in eos_token_ids:
                eos_token_ids.append(token_id)

        # Pass a list only when there are several stop ids.
        if len(eos_token_ids) > 1:
            eos_token_id = eos_token_ids
        elif eos_token_ids:
            eos_token_id = eos_token_ids[0]
        else:
            eos_token_id = None

        # Fall back to EOS only when no pad token is defined; an explicit
        # None check keeps a valid pad id of 0.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # Determine sampling strategy
        if do_sample is None:
            do_sample = temperature > 0

        # Build generation config
        gen_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": max_new_tokens,
            "pad_token_id": pad_token_id,
            "eos_token_id": eos_token_id,
            "use_cache": True,
        }
        if do_sample:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p
            gen_kwargs["top_k"] = top_k
        else:
            gen_kwargs["do_sample"] = False

        # Add any additional kwargs (but don't override what we set)
        for k, v in kwargs.items():
            if k not in gen_kwargs:
                gen_kwargs[k] = v

        # Generate
        with torch.inference_mode():
            output_ids = self.language_model.generate(**gen_kwargs)

        # Decode only new tokens
        generated_ids = output_ids[:, input_length:]
        response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        response = response.strip()

        # Clean up stop tokens from response
        for token in self.template.stop_tokens:
            if response.endswith(token):
                response = response[:-len(token)].strip()
        return response

    @property
    def device(self):
        """Device of the first model parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first model parameter."""
        return next(self.parameters()).dtype
# =============================================================================
# REGISTER WITH AUTO CLASSES
# =============================================================================
# Import-time side effect: register the config under the "smb_unstructured"
# model type, then map that config class to the causal-LM wrapper.
AutoConfig.register("smb_unstructured", SMBUnstructuredConfig)
AutoModelForCausalLM.register(SMBUnstructuredConfig, SMBUnstructuredForCausalLM)