import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers import BertTokenizer

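# NOTE: this is a simplified stand-in, not the real Gemma 3 architecture.
# nn.TransformerEncoderLayer applies full bidirectional self-attention with no
# causal mask, and there are no positional encodings at all, so this stack does
# not behave like a true decoder-only language model.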
class Gemma3ForConditionalGeneration(nn.Module):
    def __init__(self, vocab_size, embedding_dim=1344, num_heads=64, num_layers=48):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True so tensors flow as (batch, seq, dim), matching the
        # (batch, seq) input_ids produced by the tokenizer; without it, the
        # layer would treat the batch dimension as the sequence dimension
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(embedding_dim, vocab_size)

    def forward(self, input_ids):
        text_embeddings = self.token_embeddings(input_ids)
        for layer in self.transformer_layers:
            text_embeddings = layer(text_embeddings)
        return self.output_layer(text_embeddings)

def load_model(model_path, vocab_size):
    model = Gemma3ForConditionalGeneration(vocab_size=vocab_size)
    model_weights = load_file(model_path)
    # strict=False tolerates missing/unexpected keys, since this simplified
    # architecture will not line up with the checkpoint's parameter names
    model.load_state_dict(model_weights, strict=False)
    model.eval()
    return model

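# Autoregressive sampling loop: each step feeds the whole sequence back through
# the model, scales the final position's logits by temperature, and samples the
# next token from the resulting distribution.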
def generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0):
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    generated_ids = input_ids

    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(generated_ids)
        # Only the last position's logits predict the next token
        next_token_logits = outputs[:, -1, :]

        # Temperature scaling: values < 1 sharpen the distribution, > 1 flatten it
        next_token_logits = next_token_logits / temperature

        probabilities = torch.softmax(next_token_logits, dim=-1)

        # multinomial returns shape (batch, 1), so it concatenates along the
        # sequence dimension directly; the original .unsqueeze(0) produced a
        # 3-D tensor that torch.cat would reject
        next_token = torch.multinomial(probabilities, num_samples=1)
        generated_ids = torch.cat((generated_ids, next_token), dim=1)

        # Stop early once the tokenizer's end-of-sequence token is sampled
        if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
            break

    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text

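# Running this end to end assumes a local './model.safetensors'. Because the
# architecture above will not match a real checkpoint's parameter names,
# strict=False loading leaves most weights randomly initialized, so the output
# will be incoherent; the script demonstrates the sampling mechanics only.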
if __name__ == "__main__":
    # BertTokenizer is a placeholder here, not Gemma's own tokenizer; size the
    # embedding table from it so every sampled id can be decoded (a fixed
    # vocab_size larger than the tokenizer's would crash tokenizer.decode)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    vocab_size = tokenizer.vocab_size
    model_path = './model.safetensors'
    model = load_model(model_path, vocab_size)

    prompt = "What is the capital of France?"

    generated_text = generate_text(model, tokenizer, prompt, temperature=0.7)
    print("Generated Text:", generated_text)