"""Simple text generation demo for the converted ruGPT-3 XL model."""
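
# Example invocations (the script filename here is illustrative; the flags
# are the ones defined by the argparse options below):
#   python rugpt3xl_demo.py --prompt "Москва - столица"
#   python rugpt3xl_demo.py --interactive --device cuda --dtype float16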

import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def main():
    parser = argparse.ArgumentParser(description="ruGPT-3 XL text generation demo")
    parser.add_argument(
        "--model_path",
        type=str,
        default="evilfreelancer/ruGPT3XL",
        help="Path to the converted model directory",
    )
    parser.add_argument("--prompt", type=str, default=None, help="Text prompt")
    parser.add_argument(
        "--max_new_tokens", type=int, default=128, help="Max tokens to generate"
    )
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--repetition_penalty", type=float, default=1.2)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        choices=["float16", "float32", "bfloat16"],
    )
    parser.add_argument(
        "--interactive", action="store_true", help="Interactive multi-turn mode"
    )
    args = parser.parse_args()
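
    # Map the CLI flag to the corresponding torch dtype; float16/bfloat16
    # halve memory versus float32 and are the usual choice on GPU.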
    dtype_map = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    torch_dtype = dtype_map[args.dtype]

    print(f"Loading model from {args.model_path} ...")
    print(f"Device: {args.device}, dtype: {args.dtype}")
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path, trust_remote_code=True
    )
    # GPT-style tokenizers often ship without a pad token; reuse EOS so
    # generate() receives a valid pad_token_id below.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    ).to(args.device)
    model.eval()

    print(f"Model loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()

    if args.interactive:
        run_interactive(model, tokenizer, args)
    elif args.prompt:
        run_single(model, tokenizer, args.prompt, args)
    else:
        # No prompt given: run a few built-in Russian demo prompts.
        prompts = [
            "Москва - столица",  # "Moscow is the capital"
            "Искусственный интеллект - это",  # "Artificial intelligence is"
            "В далеком космосе",  # "In deep space"
        ]
        for prompt in prompts:
            run_single(model, tokenizer, prompt, args)
            print("-" * 60)


def run_single(model, tokenizer, prompt, args):
    print(f"Prompt: {prompt}")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
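
    # Sample with the CLI-provided settings; torch.no_grad() skips autograd
    # bookkeeping during generation.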
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )

    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Generated: {generated}\n")


def run_interactive(model, tokenizer, args):
    print("Interactive mode. Type 'quit' to exit.\n")
    while True:
        try:
            prompt = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break

        if prompt.lower() in ("quit", "exit", "q"):
            print("Goodbye!")
            break
        if not prompt:
            continue
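
        # Base ruGPT-3 XL is a plain causal LM, so its tokenizer may not
        # define a chat template; fall back to a simple question/answer
        # layout (an assumed format matching the "Ответ: " marker below).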
        messages = [{"role": "user", "content": prompt}]
        if getattr(tokenizer, "chat_template", None):
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            text = f"Вопрос: {prompt}\nОтвет: "  # "Question: ... Answer: "

        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
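
        # Decode only the newly generated tokens; slicing the full decoded
        # string by len(text) is unreliable once special tokens are skipped.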
        new_tokens = outputs[0][inputs["input_ids"].shape[1] :]
        answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        answer_marker = "Ответ: "  # "Answer: "
        if answer_marker in answer:
            answer = answer.split(answer_marker)[-1].strip()

        print(f"Model: {answer}\n")


if __name__ == "__main__":
    main()