BICORP commited on
Commit
af2099c
·
verified ·
1 Parent(s): d2e3c6d

Upload directory

Browse files
Files changed (1) hide show
  1. inference.py +61 -0
inference.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from safetensors.torch import load_file
4
+ from transformers import BertTokenizer
5
+
6
+ class Gemma3ForConditionalGeneration(nn.Module):
7
+ def __init__(self, vocab_size, embedding_dim=1344, num_heads=64, num_layers=48):
8
+ super(Gemma3ForConditionalGeneration, self).__init__()
9
+ self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
10
+ self.transformer_layers = nn.ModuleList([
11
+ nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads) for _ in range(num_layers)
12
+ ])
13
+ self.output_layer = nn.Linear(embedding_dim, vocab_size)
14
+
15
+ def forward(self, input_ids):
16
+ text_embeddings = self.token_embeddings(input_ids)
17
+ for layer in self.transformer_layers:
18
+ text_embeddings = layer(text_embeddings)
19
+ return self.output_layer(text_embeddings)
20
+
21
+ def load_model(model_path, vocab_size):
22
+ model = Gemma3ForConditionalGeneration(vocab_size=vocab_size)
23
+ model_weights = load_file(model_path)
24
+ model.load_state_dict(model_weights, strict=False)
25
+ model.eval()
26
+ return model
27
+
28
+ def generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0):
29
+ input_ids = tokenizer(prompt, return_tensors='pt').input_ids
30
+ generated_ids = input_ids
31
+
32
+ for _ in range(max_length):
33
+ with torch.no_grad():
34
+ outputs = model(generated_ids)
35
+ next_token_logits = outputs[:, -1, :] # Get the logits for the last token
36
+
37
+ # Apply temperature
38
+ next_token_logits = next_token_logits / temperature
39
+
40
+ # Use softmax to get probabilities
41
+ probabilities = torch.softmax(next_token_logits, dim=-1)
42
+
43
+ # Sample from the distribution
44
+ next_token = torch.multinomial(probabilities, num_samples=1) # Sample a token
45
+
46
+ generated_ids = torch.cat((generated_ids, next_token.unsqueeze(0)), dim=1) # Append the predicted token
47
+
48
+ generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
49
+ return generated_text
50
+
51
+ if __name__ == "__main__":
52
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
53
+ vocab_size = 262208 // 4
54
+ model_path = './model.safetensors' # Replace with your model path
55
+ model = load_model(model_path, vocab_size)
56
+
57
+ prompt = "What is the capital of France?"
58
+
59
+ # Generate text based on the prompt with a specified temperature
60
+ generated_text = generate_text(model, tokenizer, prompt, temperature=0.7) # Adjust temperature as needed
61
+ print("Generated Text:", generated_text)