# Nano_Ethos50M / run_model.py
# Uploaded via Colab (Ekrem-the-second, commit 01d8db4, verified)
import torch
import torch.nn.functional as F
import tiktoken
import os
# ==========================================
# SETTINGS
# ==========================================
model_path = "/content/yagiz_gpt_full_packaged.pt"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
block_size = 512  # Context window size of the model

# ==========================================
# 1. LOAD PACKAGED MODEL
# ==========================================
print(f"Device: {device}")

if not os.path.exists(model_path):
    raise FileNotFoundError(f"ERROR: File {model_path} not found. Please make sure the model is packaged correctly.")

print(f"Loading {model_path}...")

# MAGIC PART: No class definitions needed, just loading the TorchScript model.
try:
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Failed to load the model: {e}")
    # Exit with a non-zero status so callers/CI can detect the failure.
    # (The builtin `exit()` is a site.py interactive helper and exits with 0.)
    raise SystemExit(1)
# ==========================================
# 2. TOKENIZER SETUP
# ==========================================
# Using 'tiktoken' since the model was trained with GPT-2 tokenizer (vocab_size=50257)
try:
    enc = tiktoken.get_encoding("gpt2")
except Exception:
    # NOTE(review): tiktoken is already imported at the top of the file, so a
    # truly missing library would have crashed before this point. This
    # best-effort fallback mainly covers a broken/partial install in Colab.
    # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
    print("Tiktoken library missing. Installing...")
    os.system("pip install tiktoken")
    import tiktoken
    enc = tiktoken.get_encoding("gpt2")


# Helper functions for encoding and decoding (defs instead of lambda
# assignment, per PEP 8 / E731).
def encode(s):
    """Encode text to GPT-2 token ids; the <|endoftext|> marker is allowed."""
    return enc.encode(s, allowed_special={"<|endoftext|>"})


def decode(l):
    """Decode a list of token ids back into text."""
    return enc.decode(l)
# ==========================================
# 3. RESPONSE GENERATION FUNCTION
# ==========================================
def generate_response(prompt, max_new_tokens=100):
    """Autoregressively sample up to ``max_new_tokens`` tokens after ``prompt``.

    Args:
        prompt: Input text; tokenized and used as the initial context.
        max_new_tokens: Maximum number of new tokens to sample.

    Returns:
        The decoded text of the prompt plus the generated continuation.
    """
    # 1. Convert text to tensor indices, shape (1, T).
    idx = torch.tensor([encode(prompt)], dtype=torch.long, device=device)

    # Inference only: disable autograd so each forward pass does not build
    # gradient state (the original looped without no_grad()).
    with torch.no_grad():
        # 2. Generate token by token.
        for _ in range(max_new_tokens):
            # Crop context if it exceeds the model's block size.
            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
            # Forward pass — TorchScript models are called like functions.
            logits = model(idx_cond)
            # Focus on the logits of the last position only.
            logits = logits[:, -1, :]
            # Softmax to probabilities, then sample from the distribution.
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append the sampled token to the running sequence.
            idx = torch.cat((idx, idx_next), dim=1)

    # 3. Decode indices back to text.
    return decode(idx[0].tolist())
# ==========================================
# 4. START CHAT INTERFACE
# ==========================================
banner = "=" * 40
print(f"\n{banner}")
print("YAGIZ GPT (FULL PACKAGED) - READY")
print("Type 'q' and press Enter to exit.")
print(f"{banner}\n")

while True:
    user_input = input("Ask a question: ")
    if user_input.lower() == 'q':
        print("Exiting...")
        break

    # Prompt Engineering: Guiding the model with English format
    prompt = f"Question: {user_input}\nAnswer:"
    print(">> Model is thinking...")

    try:
        response = generate_response(prompt)
        # Post-processing: keep only the text after the last 'Answer:' marker;
        # fall back to the raw response if the format breaks.
        answer_only = response.split("Answer:")[-1].strip() if "Answer:" in response else response
        print(f"\nAnswer: {answer_only}\n")
        print("-" * 30)
    except Exception as e:
        print(f"An error occurred: {e}")