import warnings

import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers import BertTokenizer


class Gemma3ForConditionalGeneration(nn.Module):
    """Token embedding -> stack of ``nn.TransformerEncoderLayer`` -> vocabulary projection.

    Args:
        vocab_size: number of rows in the embedding table and outputs of the final
            linear layer.
        embedding_dim: model width (``d_model``); must be divisible by ``num_heads``.
        num_heads: attention heads per layer.
        num_layers: number of stacked encoder layers.
    """

    def __init__(self, vocab_size, embedding_dim=1344, num_heads=64, num_layers=48):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: activations are (batch, seq, dim), matching how
        # ``forward`` is fed token ids and how callers index ``[:, -1, :]``.
        # With PyTorch's default (False) the sequence axis was treated as the
        # batch axis, so every token attended only to itself.
        self.transformer_layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=embedding_dim, nhead=num_heads, batch_first=True
                )
                for _ in range(num_layers)
            ]
        )
        self.output_layer = nn.Linear(embedding_dim, vocab_size)

    def forward(self, input_ids):
        """Map token ids of shape (batch, seq) to logits of shape (batch, seq, vocab_size)."""
        hidden = self.token_embeddings(input_ids)
        # NOTE(review): no attention mask is passed, so every position attends
        # to the whole sequence (bidirectional). If the checkpoint was trained
        # as a causal LM, a causal mask should be supplied here — confirm.
        for layer in self.transformer_layers:
            hidden = layer(hidden)
        return self.output_layer(hidden)


def load_model(model_path, vocab_size, *, strict=False):
    """Build the model, load safetensors weights into it and switch it to eval mode.

    Args:
        model_path: path to a ``.safetensors`` checkpoint file.
        vocab_size: vocabulary size used to build the embedding/output layers.
        strict: forwarded to ``load_state_dict``. Defaults to ``False`` (the
            previous behaviour); in that case any missing or unexpected keys are
            reported with a ``UserWarning`` instead of being silently ignored.

    Returns:
        The ``Gemma3ForConditionalGeneration`` instance in eval mode.
    """
    model = Gemma3ForConditionalGeneration(vocab_size=vocab_size)
    state_dict = load_file(model_path)
    result = model.load_state_dict(state_dict, strict=strict)
    missing, unexpected = result.missing_keys, result.unexpected_keys
    if missing or unexpected:
        # With strict=False these parameters keep their random initialisation;
        # surface that instead of producing silent garbage output.
        warnings.warn(
            f"Checkpoint {model_path!r} does not match the model: "
            f"{len(missing)} missing key(s) (e.g. {missing[:3]}), "
            f"{len(unexpected)} unexpected key(s) (e.g. {unexpected[:3]}).",
            stacklevel=2,
        )
    model.eval()
    return model


def generate_text(
    model, tokenizer, prompt, max_length=50, temperature=1.0, *, eos_token_id=None
):
    """Autoregressively sample a continuation of ``prompt`` and decode it.

    Args:
        model: callable mapping a (batch, seq) id tensor to (batch, seq, vocab) logits.
        tokenizer: object callable as ``tokenizer(prompt, return_tensors='pt')``
            returning ``.input_ids``, and providing ``decode``.
        prompt: input text.
        max_length: maximum number of NEW tokens to sample (the prompt is not
            counted).
        temperature: softmax temperature; must be > 0. Lower is more greedy.
        eos_token_id: optional token id; sampling stops early once every
            sequence in the batch emits it in the same step.

    Returns:
        The decoded prompt plus continuation, with special tokens skipped.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        # Dividing logits by 0 (or a negative) yields inf/NaN probabilities.
        raise ValueError(f"temperature must be > 0, got {temperature!r}")

    generated_ids = tokenizer(prompt, return_tensors="pt").input_ids

    with torch.no_grad():
        for _ in range(max_length):
            # Logits for the last position only: (batch, vocab).
            next_token_logits = model(generated_ids)[:, -1, :]
            probabilities = torch.softmax(next_token_logits / temperature, dim=-1)
            # multinomial returns (batch, 1), already the rank torch.cat needs.
            next_token = torch.multinomial(probabilities, num_samples=1)
            generated_ids = torch.cat((generated_ids, next_token), dim=1)
            if eos_token_id is not None and bool((next_token == eos_token_id).all()):
                break

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # NOTE(review): the origin of 262208 // 4 is not explained here, and it
    # presumably differs from the BERT tokenizer's vocabulary size; sampled ids
    # outside the tokenizer vocab may not decode to meaningful text — confirm.
    vocab_size = 262208 // 4
    model_path = './model.safetensors'  # Replace with your model path
    model = load_model(model_path, vocab_size)

    prompt = "What is the capital of France?"
    # Generate text based on the prompt with a specified temperature
    generated_text = generate_text(model, tokenizer, prompt, temperature=0.7)  # Adjust temperature as needed
    print("Generated Text:", generated_text)