from __future__ import annotations

import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str, default="./")
    ap.add_argument("--prompt", type=str, required=True)
    ap.add_argument("--image", type=str, default=None, help="URL or local path")
    args = ap.parse_args()

    # trust_remote_code is required because the checkpoint ships its own model class.
    model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    model.eval()

    # Prefer the checkpoint's own chat() helper when the remote code provides one;
    # it takes care of prompt formatting and the optional image input.
    if hasattr(model, "chat"):
        text = model.chat(prompt=args.prompt, image=args.image, tokenizer=tok)
        print(text)
        return

    # Fallback: plain text-only generation through the standard generate() API.
    inputs = tok(args.prompt, return_tensors="pt")
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=128)
    print(tok.decode(out[0], skip_special_tokens=True))

if __name__ == "__main__":
    main()
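
A minimal usage sketch, assuming the script is saved as chat_demo.py in the checkpoint directory (the file name, the example prompt, and the image path are illustrative, not from the source):

    # multimodal path, taken when the remote code exposes model.chat()
    python chat_demo.py --prompt "Describe the image." --image ./example.jpg

    # text-only fallback against a checkpoint stored elsewhere
    python chat_demo.py --model /path/to/checkpoint --prompt "Hello"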