|
|
import os |
|
|
import pickle |
|
|
import torch |
|
|
from src.model import RippleGPT |
|
|
from src.config import RippleConfig |
|
|
|
|
|
|
|
|
# --- Sampling configuration -------------------------------------------------

# Directory where training checkpoints ('ckpt_best.pt' / 'ckpt.pt') are stored.
out_dir = 'out'


# Number of samples per prompt.
# NOTE(review): not referenced anywhere in the visible code below — confirm
# whether it should feed into the generation loop or be removed.
num_samples = 1


# Number of tokens to generate after each prompt.
max_new_tokens = 200


# Softmax temperature for sampling: <1.0 sharpens, >1.0 flattens the distribution.
temperature = 0.8


# Restrict sampling to the 200 most likely next tokens.
top_k = 200


# Prefer the Apple-Silicon GPU backend (MPS) when available, else fall back to CPU.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
|
|
|
|
|
|
|
|
def main():
    """Load the trained RippleGPT checkpoint and run multi-domain sampling.

    Seeds the RNG for reproducibility, rebuilds the model from the best
    available checkpoint, loads the character-level tokenizer, then prints
    one generated continuation per test prompt.
    """
    torch.manual_seed(1337)  # deterministic sampling across runs

    model = _load_model()

    codec = _load_codec()
    if codec is None:
        return  # meta.pkl missing; error message already printed
    encode, decode = codec

    _sample_all(model, encode, decode)


def _load_model():
    """Rebuild the model from the best available checkpoint in ``out_dir``.

    Prefers 'ckpt_best.pt' and falls back to 'ckpt.pt' with a warning.
    Returns the model in eval mode, moved to ``device``.
    """
    ckpt_path = os.path.join(out_dir, 'ckpt_best.pt')
    if not os.path.exists(ckpt_path):
        ckpt_path = os.path.join(out_dir, 'ckpt.pt')
        print("⚠️ Aviso: 'ckpt_best.pt' não encontrado, usando o último 'ckpt.pt'")

    print(f"Loading checkpoint from {ckpt_path}...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # (needed here because the checkpoint stores non-tensor 'model_args') —
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    gptconf = RippleConfig(**checkpoint['model_args'])
    model = RippleGPT(gptconf)

    # Strip the '_orig_mod.' prefix that torch.compile() adds to state-dict
    # keys so they match the uncompiled module's parameter names.
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)

    model.eval()
    model.to(device)
    return model


def _load_codec():
    """Load the character-level vocabulary from data/meta.pkl.

    Returns an ``(encode, decode)`` pair of callables, or ``None`` when the
    meta file does not exist (an error is printed in that case).
    """
    meta_path = os.path.join('data', 'meta.pkl')
    if not os.path.exists(meta_path):
        print("❌ ERRO: meta.pkl não encontrado! Rode prepare_data.py primeiro.")
        return None

    print(f"Loading meta from {meta_path}...")
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    stoi, itos = meta['stoi'], meta['itos']

    # Characters absent from the training vocab map to '?' — or to id 0 if
    # '?' itself is not in the vocab.
    unknown_token = stoi.get('?', 0)

    def encode(text):
        """Map a string to a list of token ids."""
        return [stoi.get(ch, unknown_token) for ch in text]

    def decode(ids):
        """Map a list of token ids back to a string."""
        return ''.join(itos[i] for i in ids)

    return encode, decode


def _sample_all(model, encode, decode):
    """Generate and print one continuation for each domain-specific prompt.

    The newly generated tail is highlighted in blue via ANSI escape codes.
    """
    test_cases = [
        {
            "domain": "🐍 PYTHON CODING",
            "prompt": "# Function to calculate factorial\ndef factorial(n):\n if n == 0:\n return 1\n else:\n return"
        },
        {
            "domain": "🧮 MATH LOGIC",
            "prompt": "Q: Solve 2x = 10\nA: x = 5\n\nQ: Solve -5k + 5 = -10\nA:"
        },
        {
            "domain": "📖 TINY STORY",
            "prompt": "Once upon a time, there was a little frog. The frog liked to jump. One day,"
        },
        {
            "domain": "⚔️ LITERATURE",
            "prompt": "The General looked at the map and shouted,"
        }
    ]

    print("\n" + "=" * 40)
    print("🤖 RIPPLE GPT: MULTI-DOMAIN TEST")
    print("=" * 40)

    # No gradients needed for inference; saves memory and time.
    with torch.no_grad():
        for case in test_cases:
            prompt = case["prompt"]
            domain = case["domain"]

            print(f"\n[{domain}] Prompt: {prompt.strip()}")
            print("-" * 20)

            # Add a leading batch dimension: generate() expects shape (1, T).
            start_ids = encode(prompt)
            x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            generated_text = decode(y[0].tolist())

            # Echo the prompt normally; render only the new tokens in blue.
            new_content = generated_text[len(prompt):]
            print(f"{prompt}\033[94m{new_content}\033[0m")
            print("-" * 40)
|
|
# Run sampling only when executed as a script, not when imported as a module.
if __name__ == '__main__':


    main()