# Test-25 / inference.py
import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers import BertTokenizer
class Gemma3ForConditionalGeneration(nn.Module):
    """Simplified decoder-style stand-in for Gemma 3; note it has no positional encodings."""

    def __init__(self, vocab_size, embedding_dim=1344, num_heads=64, num_layers=48):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True so each layer accepts the (batch, seq, dim) tensors produced below
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(embedding_dim, vocab_size)

    def forward(self, input_ids):
        text_embeddings = self.token_embeddings(input_ids)  # (batch, seq, embedding_dim)
        for layer in self.transformer_layers:
            text_embeddings = layer(text_embeddings)
        return self.output_layer(text_embeddings)  # (batch, seq, vocab_size)
def load_model(model_path, vocab_size):
    model = Gemma3ForConditionalGeneration(vocab_size=vocab_size)
    model_weights = load_file(model_path)
    # strict=False silently skips tensors whose names or shapes do not match this architecture
    model.load_state_dict(model_weights, strict=False)
    model.eval()
    return model
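
# Optional sketch: because strict=False skips mismatched tensors silently, it can help to
# see which checkpoint keys actually loaded. `report_loaded_keys` is an illustrative
# helper name, not part of the original script; it relies on load_state_dict returning an
# _IncompatibleKeys named tuple when strict=False.
def report_loaded_keys(model, model_weights):
    result = model.load_state_dict(model_weights, strict=False)
    print("Missing keys (in model, absent from checkpoint):", result.missing_keys)
    print("Unexpected keys (in checkpoint, absent from model):", result.unexpected_keys)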
def generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0):
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    generated_ids = input_ids
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(generated_ids)
        next_token_logits = outputs[:, -1, :]  # logits for the last position
        # Temperature rescales the logits: <1.0 sharpens the distribution, >1.0 flattens it
        next_token_logits = next_token_logits / temperature
        probabilities = torch.softmax(next_token_logits, dim=-1)
        # Sample one token id; multinomial already returns shape (batch, 1)
        next_token = torch.multinomial(probabilities, num_samples=1)
        generated_ids = torch.cat((generated_ids, next_token), dim=1)  # append along the sequence
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text
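
# Optional sketch: a top-k variant of the sampler above. Restricting sampling to the k
# most likely tokens often reduces degenerate output at higher temperatures.
# `generate_text_topk` and the `top_k` parameter are illustrative names, not part of
# the original script.
def generate_text_topk(model, tokenizer, prompt, max_length=50, temperature=1.0, top_k=50):
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    generated_ids = input_ids
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(generated_ids)
        next_token_logits = outputs[:, -1, :] / temperature
        # Keep only the top_k logits; softmax is taken over that restricted set
        top_values, top_indices = torch.topk(next_token_logits, top_k, dim=-1)
        probabilities = torch.softmax(top_values, dim=-1)
        sampled = torch.multinomial(probabilities, num_samples=1)  # index into the top-k set
        next_token = top_indices.gather(-1, sampled)               # map back to vocab ids
        generated_ids = torch.cat((generated_ids, next_token), dim=1)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)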
if __name__ == "__main__":
    # BertTokenizer is a stand-in here: its token ids (< 30522) fit within this vocab size,
    # but a real Gemma checkpoint would ship its own tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    vocab_size = 262208 // 4  # must match the embedding table stored in the checkpoint
    model_path = './model.safetensors'  # Replace with your model path
    model = load_model(model_path, vocab_size)

    prompt = "What is the capital of France?"
    # Lower temperature (e.g. 0.7) makes sampling more conservative; adjust as needed
    generated_text = generate_text(model, tokenizer, prompt, temperature=0.7)
    print("Generated Text:", generated_text)