#!/usr/bin/env python3
"""Simple text generation demo for the converted ruGPT-3 XL model."""
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
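
# Example invocations (flags mirror the argparse options defined in main();
# the default --model_path is the Hub repo id):
#   python generate.py --prompt "Искусственный интеллект - это"
#   python generate.py --interactive --device cuda --dtype float16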


def main():
    parser = argparse.ArgumentParser(description="ruGPT-3 XL text generation demo")
    parser.add_argument(
        "--model_path",
        type=str,
        default="evilfreelancer/ruGPT3XL",
        help="Path to the converted model directory",
    )
    parser.add_argument("--prompt", type=str, default=None, help="Text prompt")
    parser.add_argument(
        "--max_new_tokens", type=int, default=128, help="Max tokens to generate"
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument("--top_k", type=int, default=50, help="Top-k sampling cutoff")
    parser.add_argument(
        "--top_p", type=float, default=0.9, help="Nucleus (top-p) sampling cutoff"
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="Penalty applied to repeated tokens",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run on ('cuda' or 'cpu')",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        choices=["float16", "float32", "bfloat16"],
        help="Precision to load the weights in",
    )
    parser.add_argument(
        "--interactive", action="store_true", help="Interactive multi-turn mode"
    )
    args = parser.parse_args()

    # Map the CLI dtype string to the corresponding torch dtype.
    dtype_map = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    torch_dtype = dtype_map[args.dtype]

    print(f"Loading model from {args.model_path} ...")
    print(f"Device: {args.device}, dtype: {args.dtype}")
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path, trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        # Note: `dtype=` is the current transformers spelling; older
        # releases expect `torch_dtype=` instead.
        dtype=torch_dtype,
    ).to(args.device)
    model.eval()
    print(f"Model loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()
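    # Rough memory note (an estimate, not measured by this script): ruGPT-3 XL
    # is a ~1.3B-parameter model, so the weights alone need ~5 GB in float32
    # and roughly half that in float16/bfloat16, before activations.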

    if args.interactive:
        run_interactive(model, tokenizer, args)
    elif args.prompt:
        run_single(model, tokenizer, args.prompt, args)
    else:
        # No prompt given: run a few built-in Russian demo prompts.
        prompts = [
            "Москва - столица",  # "Moscow is the capital"
            "Искусственный интеллект - это",  # "Artificial intelligence is"
            "В далеком космосе",  # "In deep space"
        ]
        for prompt in prompts:
            run_single(model, tokenizer, prompt, args)
            print("-" * 60)


def run_single(model, tokenizer, prompt, args):
    print(f"Prompt: {prompt}")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            do_sample=True,
            # Fall back to EOS when the tokenizer defines no pad token,
            # which is common for GPT-style models.
            pad_token_id=(
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.eos_token_id
            ),
        )
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Generated: {generated}\n")


def run_interactive(model, tokenizer, args):
    print("Interactive mode. Type 'quit' to exit.\n")
    while True:
        try:
            prompt = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break
        if prompt.lower() in ("quit", "exit", "q"):
            print("Goodbye!")
            break
        if not prompt:
            continue
        # Assumes the tokenizer (loaded with trust_remote_code=True) defines
        # a chat template; recent transformers raise if none is set.
        messages = [{"role": "user", "content": prompt}]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
                do_sample=True,
                pad_token_id=(
                    tokenizer.pad_token_id
                    if tokenizer.pad_token_id is not None
                    else tokenizer.eos_token_id
                ),
            )
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The chat template is expected to end the prompt with "Ответ: "
        # (Russian for "Answer: "); strip everything up to that marker, or
        # fall back to slicing off the rendered prompt text.
        answer_marker = "Ответ: "
        if answer_marker in full_text:
            answer = full_text.split(answer_marker)[-1].strip()
        else:
            answer = full_text[len(text):].strip()
        print(f"Model: {answer}\n")


if __name__ == "__main__":
    main()