| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
# REPO_ROOT is the parent of the directory that contains this script
# (``parents[0]`` is the script's own directory, ``parents[1]`` the one above).
# It is put at the front of ``sys.path`` -- only when it is not already
# listed -- before the project import that follows; presumably that is where
# the ``reframr`` package lives (confirm against the repository layout).
REPO_ROOT = Path(__file__).resolve().parents[1]
if not (str(REPO_ROOT) in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from reframr.model import ReframrModel |
|
|
|
|
def main() -> None:
    """Parse the command line, load the model, and print one generation.

    Recognised options (all optional):

    * ``--model``: path handed to ``ReframrModel.load``; defaults to
      ``model.safetensors`` directly under ``REPO_ROOT``.
    * ``--prompt``: the user prompt text.
    * ``--system``: optional system instruction; when it is non-blank after
      stripping whitespace, it is prepended to the prompt.
    * ``--max-tokens``, ``--temperature``, ``--top-k``, ``--top-p``,
      ``--repetition-penalty``: forwarded unchanged as keyword arguments to
      ``generate_text``.

    The value returned by ``generate_text`` is written to stdout via ``print``.
    """
    cli = argparse.ArgumentParser(description="Run Reframr-RFM-v1-Base locally.")
    cli.add_argument("--model", default=str(REPO_ROOT / "model.safetensors"))
    cli.add_argument("--prompt", default="Who are you, and what makes Reframr different?")
    cli.add_argument("--system", default="")

    # Numeric sampling options, registered in the same order as they appear
    # in the help output: (flag, converter, default).
    numeric_options = (
        ("--max-tokens", int, 90),
        ("--temperature", float, 0.92),
        ("--top-k", int, 72),
        ("--top-p", float, 0.92),
        ("--repetition-penalty", float, 1.18),
    )
    for flag, converter, fallback in numeric_options:
        cli.add_argument(flag, type=converter, default=fallback)

    opts = cli.parse_args()

    # A blank (or whitespace-only) system instruction leaves the prompt as-is.
    system_text = opts.system.strip()
    if system_text:
        context = f"System instruction: {system_text}\nUser: {opts.prompt}"
    else:
        context = opts.prompt

    model = ReframrModel.load(opts.model)
    sampling = {
        "max_tokens": opts.max_tokens,
        "temperature": opts.temperature,
        "top_k": opts.top_k,
        "top_p": opts.top_p,
        "repetition_penalty": opts.repetition_penalty,
    }
    generated = model.generate_text(context, **sampling)
    print(generated)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|