import os import torch import torch.nn.functional as F from collections import OrderedDict import string import sys from model import ChatGCLM, MAX_SEQ_LEN # --- Configuration --- EOS_ID = 2 OFFSET = 3 CHARS = string.printable def get_model_path(): """Finds the first model file starting with Turing_ in the current directory.""" for f in os.listdir("."): if f.startswith("Turing_") and f.endswith(".pt"): return f return None MODEL_PATH = get_model_path() if MODEL_PATH is None: print("Error: No model checkpoint found!") print("Please train the model first with: python3 train.py") sys.exit(1) # --- Helper Functions --- def encode(text): return [CHARS.index(c) + OFFSET for c in text if c in CHARS] def decode(ids): return "".join([CHARS[i - OFFSET] for i in ids if i >= OFFSET]) def load_model(device): vocab_size = len(CHARS) + OFFSET model = ChatGCLM(vocab_size).to(device) if os.path.exists(MODEL_PATH) and os.path.getsize(MODEL_PATH) > 0: print(f"Loading model from: {MODEL_PATH}") ckpt = torch.load(MODEL_PATH, map_location=device) if isinstance(ckpt, dict): if 'model_state_dict' in ckpt: state_dict = ckpt['model_state_dict'] elif 'state_dict' in ckpt: state_dict = ckpt['state_dict'] else: state_dict = ckpt else: state_dict = ckpt # Handle compilation prefix if present def _strip_module_prefix(sd): keys = list(sd.keys()) if any(k.startswith('module.') for k in keys): new_sd = OrderedDict() for k, v in sd.items(): new_key = k[len('module.'): ] if k.startswith('module.') else k new_sd[new_key] = v return new_sd return sd state_dict = _strip_module_prefix(state_dict) res = model.load_state_dict(state_dict, strict=False) missing = getattr(res, 'missing_keys', None) unexpected = getattr(res, 'unexpected_keys', None) if missing: print(f"Warning: missing keys when loading state_dict: {missing}") if unexpected: print(f"Warning: unexpected keys in state_dict: {unexpected}") model.eval() return model else: print(f"Error: Could not load model from {MODEL_PATH}") return None @torch.no_grad() def generate_stream(model, prompt, device, max_new_tokens=500, temperature=0.7, top_k=50): """ Generates text from the model and streams it to stdout. Returns the full generated text. """ model.eval() input_ids = encode(prompt) x = torch.tensor([input_ids], dtype=torch.long, device=device) # We don't print the prompt again, we just stream the new tokens generated_ids = [] for _ in range(max_new_tokens): # Crop context if needed ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x logits = model(ctx) next_token_logits = logits[:, -1, :] / temperature if top_k is not None: v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1))) next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf') probs = F.softmax(next_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) idx = next_token.item() if idx == EOS_ID: break x = torch.cat((x, next_token), dim=1) generated_ids.append(idx) token_text = decode([idx]) print(token_text, end="", flush=True) if len(generated_ids) >= 3 and decode(generated_ids[-3:]) == "": print('\b\b\b \b\b\b', end="", flush=True) break return decode(generated_ids) def main(): device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" print(f"Using device: {device}") model = load_model(device) if model is None: sys.exit(1) print("\n" + "="*50) print("Turing | Chat Interface") print(f"Model: {MODEL_PATH}") print("Type 'quit', 'exit', or 'q' to end the session.") print("="*50 + "\n") history = "" while True: try: # Get user input user_input = input("\n\033[1;36mYou:\033[0m ") # Cyan color for "You:" if user_input.strip().lower() in ['quit', 'exit', 'q']: print("\nGoodbye!") break if not user_input.strip(): continue print("\033[1;32mModel:\033[0m ", end="", flush=True) # Green color for "Model:" # Since this is a raw completion model, we might want to feed it the input directly # and let it continue. # Prepare the prompt with history current_turn = f" {user_input} " full_prompt = history + current_turn # Generate response response = generate_stream(model, full_prompt, device=device) # Update history # We strip from the end if it was generated as a stop token cleaned_response = response if cleaned_response.endswith(""): cleaned_response = cleaned_response[:-3] history += current_turn + cleaned_response print() # Newline after generation except KeyboardInterrupt: print("\n\nExiting...") break except Exception as e: print(f"\nError: {e}") if __name__ == "__main__": main()