"""Interactive prompt loop that samples text from the Llama model at MODEL_DIR."""

import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM

MODEL_DIR = "sapbot/toyllama-13m"

# --- Generation Settings ---
MAX_NEW_TOKENS = 150  # Increased slightly so it can finish thoughts
TEMPERATURE = 0.7
TOP_P = 0.9


def _build_byte_level_table():
    """Return the inverse byte-level alphabet: display character -> byte value.

    Byte-level BPE (GPT-2 style) keeps printable bytes as themselves and shifts
    every other byte to code point 256 + k, in ascending byte order. That is
    why a space shows up as "Ġ" (U+0120) and a newline as "Ċ" (U+010A).
    """
    kept = (
        list(range(0x21, 0x7F))     # "!" .. "~"
        + list(range(0xA1, 0xAD))   # "¡" .. "¬"
        + list(range(0xAE, 0x100))  # "®" .. "ÿ"
    )
    kept_set = set(kept)
    table = {chr(value): value for value in kept}
    shifted = 0
    for value in range(256):
        if value not in kept_set:
            table[chr(256 + shifted)] = value
            shifted += 1
    return table


_BYTE_LEVEL_CHAR_TO_BYTE = _build_byte_level_table()


def _decode_byte_level_tokens(tokens, skip_tokens=()):
    """Glue byte-level token strings together and decode them to real text.

    Args:
        tokens: Iterable of token strings (entries may be ``None``).
        skip_tokens: Container of token strings to drop, e.g. BOS/EOS/PAD.

    Returns:
        The decoded text. Tokens are concatenated with no separator, so
        subword pieces stay joined ("M" + "iddle" -> "Middle"). Byte sequences
        that are not valid UTF-8 (for example a multi-byte character cut off
        by the token limit) become U+FFFD instead of raising.
    """
    raw = bytearray()
    for token in tokens:
        if token is None or token in skip_tokens:
            continue
        for char in token:
            byte = _BYTE_LEVEL_CHAR_TO_BYTE.get(char)
            if byte is None:
                # Not part of the byte-level alphabet (e.g. a literal
                # character inside an added token): keep it as-is.
                raw.extend(char.encode("utf-8"))
            else:
                raw.append(byte)
    return raw.decode("utf-8", errors="replace")


def main() -> None:
    """Load the model, then read prompts from stdin and print sampled text.

    The loop ends when the user types 'quit' or 'exit', or on Ctrl-C / EOF.
    If the tokenizer or model cannot be loaded, the error is printed and the
    function returns without entering the loop.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running inference on: {device.upper()}")

    # Top-level boundary: any loading failure is reported and ends the run.
    try:
        print(f"Loading model from {MODEL_DIR}...")
        tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_DIR)
        model = LlamaForCausalLM.from_pretrained(MODEL_DIR)
        model.to(device)
        model.eval()
        print("Model loaded successfully!\n")
    except Exception as e:
        print(f"Failed to load. Error: {e}")
        return

    print("=" * 60)
    print("INTERACTIVE MODE: Ready! (Type 'quit' or 'exit' to stop)")
    print("=" * 60)

    # Special tokens (BOS / EOS / PAD) to drop while decoding. None is kept in
    # the set because an unset special token is reported as None.
    special_tokens = {
        tokenizer.bos_token,
        tokenizer.eos_token,
        tokenizer.pad_token,
        None,
    }

    # Loop-invariant: fall back to EOS when the tokenizer defines no PAD token.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    while True:
        try:
            prompt = input(f"\n>>> Enter prompt (temp: {TEMPERATURE}): ")
        except (KeyboardInterrupt, EOFError):
            print("\nExiting...")
            break

        command = prompt.strip()
        if command.lower() in ("quit", "exit"):
            print("Goodbye!")
            break
        if not command:
            continue

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # Dropped before generate(); a no-op when the tokenizer does not emit it.
        inputs.pop("token_type_ids", None)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                do_sample=True,
                pad_token_id=pad_token_id,
            )

        # ---------------------------------------------------------------
        # Manual subword gluing
        # ---------------------------------------------------------------
        # The raw token pieces are joined with NO separator (fixes
        # "M iddle" -> "Middle"), then the byte-level alphabet is mapped back
        # to real bytes. Inverting the whole alphabet, rather than only
        # "Ġ" -> " " and "Ċ" -> "\n", also restores tabs and non-ASCII text.
        token_ids = outputs[0].tolist()
        raw_tokens = tokenizer.convert_ids_to_tokens(token_ids)
        generated_text = _decode_byte_level_tokens(raw_tokens, special_tokens)
        # ---------------------------------------------------------------

        print("-" * 60)
        print(generated_text.strip())
        print("-" * 60)


if __name__ == "__main__":
    main()