#!/usr/bin/env python3
"""Zenith-7B Inference Script for Standard GPUs"""

import torch
import argparse
from pathlib import Path
from typing import Optional, Dict, Any

# Add current directory to path for imports
import sys
sys.path.append(str(Path(__file__).parent))

from configs.zenith_config import get_7b_config
from models.zenith_model import ZenithForCausalLM
from data.advanced_tokenizer import AdvancedTokenizer


def load_model(checkpoint_path: str, device: str = "cuda"):
    """Load trained model from checkpoint."""
    config = get_7b_config()

    # Initialize tokenizer
    tokenizer = AdvancedTokenizer.from_pretrained(checkpoint_path)
    config.vocab_size = tokenizer.get_vocab_size()

    # Load model
    model = ZenithForCausalLM.from_pretrained(
        checkpoint_path,
        config=config,
        device_map="auto" if device == "cuda" else None
    )
    model.eval()

    return model, tokenizer


def generate(
    model: ZenithForCausalLM,
    tokenizer: AdvancedTokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    do_sample: bool = True,
    stream: bool = False
):
    """Generate text from the model."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        if stream:
            # Streaming generation
            from transformers import TextIteratorStreamer
            streamer = TextIteratorStreamer(
                tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            generation_kwargs = dict(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                streamer=streamer
            )
            from threading import Thread
            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()
            return streamer
        else:
            outputs = model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
            return tokenizer.decode(outputs[0], skip_special_tokens=True)


def interactive_mode(model, tokenizer):
    """Run interactive chat session."""
    print("=" * 60)
    print("Zenith-7B Interactive Mode")
    print("Type 'quit' to exit, 'clear' to clear history")
    print("=" * 60)

    history = []

    while True:
        try:
            user_input = input("\nYou: ").strip()

            if user_input.lower() == 'quit':
                break
            if user_input.lower() == 'clear':
                history = []
                print("History cleared.")
                continue

            # Build prompt with history
            prompt = ""
            for user_msg, assistant_msg in history[-4:]:  # Keep last 4 exchanges
                prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
            prompt += f"User: {user_input}\nAssistant:"

            print("\nZenith: ", end="", flush=True)
            response = generate(model, tokenizer, prompt, stream=True)

            full_response = ""
            for token in response:
                print(token, end="", flush=True)
                full_response += token
            print()

            history.append((user_input, full_response))

        except KeyboardInterrupt:
            print("\n\nInterrupted. Type 'quit' to exit.")
        except Exception as e:
            print(f"\nError: {e}")


def main():
    parser = argparse.ArgumentParser(description="Zenith-7B Inference")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to model checkpoint directory"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt for generation (if not provided, enters interactive mode)"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="Maximum new tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling"
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=50,
        help="Top-k sampling"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device to run inference on"
    )
    parser.add_argument(
        "--stream",
        action="store_true",
        help="Stream output token by token"
    )
    args = parser.parse_args()

    # Load model
    print(f"Loading model from {args.checkpoint}...")
    model, tokenizer = load_model(args.checkpoint, args.device)
    print("Model loaded successfully!")

    if args.prompt:
        # Single generation
        response = generate(
            model,
            tokenizer,
            args.prompt,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            stream=args.stream
        )
        if args.stream:
            for token in response:
                print(token, end="", flush=True)
            print()
        else:
            print(f"\nResponse: {response}")
    else:
        # Interactive mode
        interactive_mode(model, tokenizer)


if __name__ == "__main__":
    main()