| |
|
| | """Zenith-7B Inference Script for Standard GPUs"""
|
| |
|
| | import torch
|
| | import argparse
|
| | from pathlib import Path
|
| | from typing import Optional, Dict, Any
|
| |
|
| |
|
| | import sys
|
| | sys.path.append(str(Path(__file__).parent))
|
| |
|
| | from configs.zenith_config import get_7b_config
|
| | from models.zenith_model import ZenithForCausalLM
|
| | from data.advanced_tokenizer import AdvancedTokenizer
|
| |
|
| |
|
def load_model(checkpoint_path: str, device: str = "cuda"):
    """Load a trained Zenith model and its tokenizer from a checkpoint directory.

    Args:
        checkpoint_path: Directory containing model weights and tokenizer files.
        device: "cuda" (uses automatic device placement) or "cpu".

    Returns:
        Tuple of (model in eval mode, tokenizer).
    """
    cfg = get_7b_config()

    # Tokenizer is loaded first so the config's vocab size can be aligned
    # with the actual tokenizer vocabulary before building the model.
    tok = AdvancedTokenizer.from_pretrained(checkpoint_path)
    cfg.vocab_size = tok.get_vocab_size()

    device_map = "auto" if device == "cuda" else None
    net = ZenithForCausalLM.from_pretrained(
        checkpoint_path,
        config=cfg,
        device_map=device_map,
    )
    net.eval()
    return net, tok
|
| |
|
| |
|
def generate(
    model: ZenithForCausalLM,
    tokenizer: AdvancedTokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    do_sample: bool = True,
    stream: bool = False
):
    """Generate text from the model.

    Args:
        model: Loaded Zenith model in eval mode.
        tokenizer: Matching tokenizer.
        prompt: Text prompt to condition generation on.
        max_new_tokens: Upper bound on generated tokens.
        temperature, top_p, top_k, repetition_penalty, do_sample:
            Standard sampling controls forwarded to ``model.generate``.
        stream: If True, return a ``TextIteratorStreamer`` the caller can
            iterate while a background thread generates; otherwise block and
            return the decoded completion string.

    Returns:
        A ``TextIteratorStreamer`` when ``stream`` is True, else a ``str``.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        # Fix: the streaming path previously omitted these, so only the
        # non-streaming path padded/stopped on the correct token ids.
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    if stream:
        from threading import Thread

        from transformers import TextIteratorStreamer

        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs["streamer"] = streamer

        # Fix: torch.no_grad() is thread-local. The original wrapped the
        # Thread *launch* in a no_grad context, which does nothing for the
        # worker thread actually running model.generate. Enter the context
        # inside the thread instead so gradients are truly disabled.
        def _generate_no_grad():
            with torch.no_grad():
                model.generate(**generation_kwargs)

        # daemon=True so an abandoned stream cannot keep the process alive.
        thread = Thread(target=_generate_no_grad, daemon=True)
        thread.start()
        return streamer

    with torch.no_grad():
        outputs = model.generate(**generation_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| |
|
| |
|
def interactive_mode(model, tokenizer):
    """Run interactive chat session."""
    banner = "=" * 60
    print(banner)
    print("Zenith-7B Interactive Mode")
    print("Type 'quit' to exit, 'clear' to clear history")
    print(banner)

    history = []
    while True:
        try:
            user_input = input("\nYou: ").strip()
            command = user_input.lower()
            if command == 'quit':
                break
            if command == 'clear':
                history = []
                print("History cleared.")
                continue

            # Build the prompt from the last 4 exchanges plus the new turn.
            turns = [
                f"User: {u}\nAssistant: {a}\n" for u, a in history[-4:]
            ]
            prompt = "".join(turns) + f"User: {user_input}\nAssistant:"

            print("\nZenith: ", end="", flush=True)
            pieces = []
            for chunk in generate(model, tokenizer, prompt, stream=True):
                print(chunk, end="", flush=True)
                pieces.append(chunk)
            print()
            full_response = "".join(pieces)

            history.append((user_input, full_response))

        except KeyboardInterrupt:
            print("\n\nInterrupted. Type 'quit' to exit.")
        except Exception as e:
            # Broad catch is deliberate: keep the REPL alive on any error.
            print(f"\nError: {e}")
|
| |
|
def main():
    """Parse CLI args, load the model, then run one-shot or interactive mode."""
    parser = argparse.ArgumentParser(description="Zenith-7B Inference")
    parser.add_argument(
        "--checkpoint", type=str, required=True,
        help="Path to model checkpoint directory",
    )
    parser.add_argument(
        "--prompt", type=str, default=None,
        help="Prompt for generation (if not provided, enters interactive mode)",
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=512,
        help="Maximum new tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--top_p", type=float, default=0.9,
        help="Top-p (nucleus) sampling",
    )
    parser.add_argument(
        "--top_k", type=int, default=50,
        help="Top-k sampling",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", choices=["cuda", "cpu"],
        help="Device to run inference on",
    )
    parser.add_argument(
        "--stream", action="store_true",
        help="Stream output token by token",
    )

    args = parser.parse_args()

    print(f"Loading model from {args.checkpoint}...")
    model, tokenizer = load_model(args.checkpoint, args.device)
    print("Model loaded successfully!")

    if not args.prompt:
        # No prompt given: drop into the chat REPL.
        interactive_mode(model, tokenizer)
        return

    response = generate(
        model, tokenizer, args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        stream=args.stream,
    )
    if args.stream:
        for token in response:
            print(token, end="", flush=True)
        print()
    else:
        print(f"\nResponse: {response}")


if __name__ == "__main__":
    main()