"""A minimal "Hello World" causal language model built on the Hugging Face Transformers API."""

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from datasets import load_dataset


class HelloWorldConfig(PretrainedConfig):
    model_type = "hello_world"

    def __init__(
        self,
        vocab_size=13,
        hidden_size=64,
        num_hidden_layers=1,
        num_attention_heads=1,
        intermediate_size=128,
        hidden_act="gelu",
        max_position_embeddings=512,
        type_vocab_size=1,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps


class HelloWorldModel(PreTrainedModel):
    config_class = HelloWorldConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # A single encoder layer; several config fields (num_hidden_layers, hidden_act,
        # layer_norm_eps) are kept for HF-config completeness but are not wired into
        # this toy block.
        self.layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            batch_first=True,
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.init_weights()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        labels=None,
        use_cache=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        # use_cache and output_attentions are accepted for interface compatibility
        # with other HF causal LMs but are not used by this toy model.
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        else:
            raise ValueError("You have to specify input_ids")

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Token embeddings plus learned absolute position embeddings.
        inputs_embeds = self.embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Causal mask (True = blocked) so each position attends only to itself and
        # earlier positions, plus an optional padding mask from attention_mask.
        causal_mask = torch.triu(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        )
        padding_mask = (attention_mask == 0) if attention_mask is not None else None
        hidden_states = self.layer(hidden_states, src_mask=causal_mask, src_key_padding_mask=padding_mask)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=(hidden_states,) if output_hidden_states else None,
            attentions=None,
        )

    def generate_hello_world(self):
        # Toy "generation": runs a forward pass over two hard-coded token ids and
        # always returns the literal string.
        hello_token_id = 5
        world_token_id = 6
        input_ids = torch.tensor([[hello_token_id, world_token_id]])
        with torch.no_grad():
            self.forward(input_ids)
        return "Hello World!"

    @classmethod
    def load_dataset(cls, dataset_name="chiedo/hello-world", split=None):
        """
        Load the Hello World dataset.

        Args:
            dataset_name (str): Name of the dataset on the Hugging Face Hub.
            split (str, optional): Specific split to load ('train', 'validation', 'test').

        Returns:
            Dataset or DatasetDict depending on the split parameter, or None on failure.
        """
        try:
            if split:
                return load_dataset(dataset_name, split=split)
            else:
                return load_dataset(dataset_name)
        except Exception as e:
            print(f"Error loading dataset: {e}")
            print(f"Make sure the dataset exists at: https://huggingface.co/datasets/{dataset_name}")
            return None

    def prepare_dataset_batch(self, texts, tokenizer, max_length=128):
        """
        Prepare a batch of texts from the dataset for model input.

        Args:
            texts (list): List of text strings.
            tokenizer: Tokenizer used to encode the texts.
            max_length (int): Maximum sequence length.

        Returns:
            dict: Dictionary with input_ids and attention_mask tensors.
        """
        return tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
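# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model API): builds the toy
# config and model, runs a forward pass on random token ids, and optionally
# pulls the companion dataset. The batch shape and token ids below are made up
# for demonstration, and loading the dataset requires network access to the
# Hugging Face Hub.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = HelloWorldConfig()
    model = HelloWorldModel(config)
    model.eval()

    # Random token ids in [0, vocab_size); reusing them as labels exercises the
    # shifted next-token loss.
    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print("loss:", outputs.loss.item())
    print("logits shape:", tuple(outputs.logits.shape))  # (2, 8, vocab_size)

    print(model.generate_hello_world())

    # Requires internet access; returns None (with a message) on failure.
    dataset = HelloWorldModel.load_dataset()
    if dataset is not None:
        print(dataset)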