# Manthan-T1/scripts/infer_hf.py
# Author: Atah Alam — "Manthan-T1 clean code-only" (commit 7f7a72e)
from __future__ import annotations
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def main() -> None:
    """Run one inference against a Hugging Face causal-LM checkpoint.

    Command-line arguments:
        --model           path or hub id of the checkpoint (default: ``./``)
        --prompt          text prompt (required)
        --image           optional URL or local path, forwarded to the
                          checkpoint's ``chat`` method when it has one
        --max-new-tokens  generation budget for the plain-LM fallback
                          (default 128, matching the previous hard-coded value)
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str, default="./")
    ap.add_argument("--prompt", type=str, required=True)
    ap.add_argument("--image", type=str, default=None, help="URL or local path")
    ap.add_argument("--max-new-tokens", type=int, default=128,
                    help="max tokens to generate in the fallback path")
    args = ap.parse_args()

    # NOTE(review): trust_remote_code=True executes Python shipped inside the
    # checkpoint repo (needed for custom `chat` methods) — only point --model
    # at trusted sources.
    model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    model.eval()  # disable dropout/training-mode layers for inference

    # Multimodal checkpoints expose a high-level `chat` helper; prefer it so
    # the optional --image argument is honored.
    if hasattr(model, "chat"):
        print(model.chat(prompt=args.prompt, image=args.image, tokenizer=tok))
        return

    # Plain causal-LM fallback: tokenize, generate, decode.
    inputs = tok(args.prompt, return_tensors="pt")
    with torch.inference_mode():  # skip autograd bookkeeping during generation
        out = model.generate(**inputs, max_new_tokens=args.max_new_tokens)
    print(tok.decode(out[0], skip_special_tokens=True))
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()