Ekrem-the-second committed on
Commit
01d8db4
·
verified ·
1 Parent(s): 6f8173d

Upload run_model.py via Colab

Browse files
Files changed (1) hide show
  1. run_model.py +112 -0
run_model.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Standard library
import os

# Third-party
import torch
import torch.nn.functional as F
import tiktoken

# ==========================================
# SETTINGS
# ==========================================
# Location of the TorchScript-packaged checkpoint (Colab path).
model_path = "/content/yagiz_gpt_full_packaged.pt"
# Run on the GPU when one is present, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
block_size = 512  # Context window size of the model
# ==========================================
# 1. LOAD PACKAGED MODEL
# ==========================================
print(f"Device: {device}")

if not os.path.exists(model_path):
    raise FileNotFoundError(f"ERROR: File {model_path} not found. Please make sure the model is packaged correctly.")

print(f"Loading {model_path}...")

# MAGIC PART: No class definitions needed, just loading the TorchScript model.
try:
    model = torch.jit.load(model_path, map_location=device)
    model.eval()  # inference mode: disables dropout / training-time behavior
    print("Model loaded successfully!")
except Exception as e:
    # FIX: the original called the bare builtin `exit()`, which is injected
    # by the `site` module and is not guaranteed to exist in non-interactive
    # runs — and it exited with status 0 on a failure path. Raising
    # SystemExit with a message prints it to stderr and exits non-zero.
    raise SystemExit(f"Failed to load the model: {e}")
# ==========================================
# 2. TOKENIZER SETUP
# ==========================================
# Using 'tiktoken' since the model was trained with GPT-2 tokenizer (vocab_size=50257)
try:
    enc = tiktoken.get_encoding("gpt2")
except Exception:
    # FIX: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt; narrowed to Exception.
    # NOTE(review): this fallback is likely dead code — `import tiktoken`
    # at the top of the file would already have raised ImportError if the
    # package were missing. Kept for behavioral compatibility.
    print("Tiktoken library missing. Installing...")
    os.system("pip install tiktoken")
    import tiktoken
    enc = tiktoken.get_encoding("gpt2")

# Helper functions for encoding and decoding.
# `allowed_special` lets the literal "<|endoftext|>" marker be encoded
# as its special token instead of raising.
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)
# ==========================================
# 3. RESPONSE GENERATION FUNCTION
# ==========================================
def generate_response(prompt, max_new_tokens=100):
    """Autoregressively sample a continuation of `prompt`.

    Args:
        prompt: Input text; it is tokenized and fed to the model.
        max_new_tokens: Maximum number of tokens to sample.

    Returns:
        The decoded text of prompt + generated continuation
        (the prompt is included in the return value).
    """
    # 1. Convert text to a (1, seq_len) tensor of token indices
    idx = torch.tensor([encode(prompt)], dtype=torch.long, device=device)

    # FIX: run generation under no_grad — the original tracked gradients
    # for every forward pass, needlessly building autograd graphs and
    # wasting memory during pure inference.
    with torch.no_grad():
        # 2. Generate token by token
        for _ in range(max_new_tokens):
            # Crop context so it never exceeds the model's block size
            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]

            # Forward pass — TorchScript models are called like functions
            logits = model(idx_cond)

            # Keep only the logits for the last position
            logits = logits[:, -1, :]

            # Convert logits to a probability distribution
            probs = F.softmax(logits, dim=-1)

            # Sample the next token from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)

            # Append the new token and continue
            idx = torch.cat((idx, idx_next), dim=1)

    # 3. Decode indices back to text
    return decode(idx[0].tolist())
# ==========================================
# 4. START CHAT INTERFACE
# ==========================================
print("\n" + "="*40)
print("YAGIZ GPT (FULL PACKAGED) - READY")
print("Type 'q' and press Enter to exit.")
print("="*40 + "\n")

while True:
    # FIX: exit gracefully when stdin is closed or the user hits Ctrl-C
    # (common in Colab) instead of crashing with a traceback.
    try:
        user_input = input("Ask a question: ")
    except (EOFError, KeyboardInterrupt):
        print("\nExiting...")
        break

    if user_input.lower() == 'q':
        print("Exiting...")
        break

    # Prompt Engineering: Guiding the model with English format
    prompt = f"Question: {user_input}\nAnswer:"

    print(">> Model is thinking...")
    try:
        response = generate_response(prompt)

        # Post-processing: keep only the text after the final 'Answer:'
        # marker so the echoed prompt is stripped from the output.
        if "Answer:" in response:
            answer_only = response.split("Answer:")[-1].strip()
        else:
            answer_only = response  # Fallback if format breaks

        print(f"\nAnswer: {answer_only}\n")
        print("-" * 30)

    except Exception as e:
        # Top-level boundary: report the failure and keep the chat alive.
        print(f"An error occurred: {e}")