File size: 3,635 Bytes
01d8db4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113

import torch
import torch.nn.functional as F
import tiktoken
import os

# ==========================================
# SETTINGS
# ==========================================
model_path = "/content/yagiz_gpt_full_packaged.pt"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
block_size = 512  # Context window size of the model

# ==========================================
# 1. LOAD PACKAGED MODEL
# ==========================================
print(f"Device: {device}")

if not os.path.exists(model_path):
    raise FileNotFoundError(f"ERROR: File {model_path} not found. Please make sure the model is packaged correctly.")

print(f"Loading {model_path}...")

# MAGIC PART: No class definitions needed, just loading the TorchScript model.
try:
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Failed to load the model: {e}")
    # NOTE: was `exit()` — that helper lives in the `site` module (absent under
    # `python -S` / frozen builds) and exits with status 0, which tells the
    # shell the script SUCCEEDED. SystemExit(1) reports the failure properly.
    raise SystemExit(1)

# ==========================================
# 2. TOKENIZER SETUP
# ==========================================
# Using 'tiktoken' since the model was trained with GPT-2 tokenizer (vocab_size=50257).
# NOTE: the previous `except: os.system("pip install tiktoken")` fallback was dead
# code — if tiktoken were missing, the `import tiktoken` at the top of the file
# would already have raised ImportError before execution ever reached this line.
enc = tiktoken.get_encoding("gpt2")


# Helper functions for encoding and decoding
def encode(s):
    """Encode text to GPT-2 token ids; the <|endoftext|> marker is allowed."""
    return enc.encode(s, allowed_special={"<|endoftext|>"})


def decode(l):
    """Decode a list of GPT-2 token ids back to text."""
    return enc.decode(l)

# ==========================================
# 3. RESPONSE GENERATION FUNCTION
# ==========================================
@torch.no_grad()  # inference only: skip autograd bookkeeping (saves memory and time)
def generate_response(prompt, max_new_tokens=100, temperature=1.0):
    """Autoregressively sample a continuation of `prompt` from the model.

    Args:
        prompt: Input text to condition on.
        max_new_tokens: Number of new tokens to sample (default 100).
        temperature: Softmax temperature; 1.0 (the default) reproduces the
            original sampling behavior, lower values sample more greedily.

    Returns:
        The decoded text: the prompt followed by the sampled continuation.
    """
    # 1. Convert text to a (1, T) tensor of token indices
    idx = torch.tensor([encode(prompt)], dtype=torch.long, device=device)

    # 2. Generate token by token
    for _ in range(max_new_tokens):
        # Crop context if it exceeds the model's block size
        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]

        # Forward pass — TorchScript models are called like plain functions
        logits = model(idx_cond)

        # Focus on the logits for the last position only
        logits = logits[:, -1, :] / temperature

        # Softmax to probabilities, then sample the next token
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)

        # Append the sampled token to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)

    # 3. Decode indices back to text
    return decode(idx[0].tolist())

# ==========================================
# 4. START CHAT INTERFACE
# ==========================================
# Simple REPL: read a question, wrap it in a Question/Answer template,
# sample a completion, and print only the answer portion.
print("\n" + "=" * 40)
print("YAGIZ GPT (FULL PACKAGED) - READY")
print("Type 'q' and press Enter to exit.")
print("=" * 40 + "\n")

while True:
    user_input = input("Ask a question: ")
    if user_input.lower() == 'q':
        print("Exiting...")
        break

    # Prompt engineering: steer the model with an English Q/A format.
    prompt = f"Question: {user_input}\nAnswer:"

    print(">> Model is thinking...")
    try:
        response = generate_response(prompt)

        # Post-processing: keep only the text after the last 'Answer:'
        # marker; if the model broke the format, show the raw response.
        marker = "Answer:"
        answer_only = response.split(marker)[-1].strip() if marker in response else response

        print(f"\nAnswer: {answer_only}\n")
        print("-" * 30)
    except Exception as e:
        print(f"An error occurred: {e}")