kieraisverybored committed
Commit 0709cbe · verified · 1 Parent(s): a24b86f

Create infer.py

Files changed (1)
  1. infer.py +93 -0
infer.py ADDED
@@ -0,0 +1,93 @@
#!/usr/bin/env python
"""
infer.py – chat with fein

Usage:
    python infer.py             # load from HF repo
    python infer.py --model .   # load from local folder
"""
import argparse, torch, readline  # readline enables line editing in input()
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
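# NOTE: the 4-/8-bit loading below relies on the `bitsandbytes` package
# (pip install bitsandbytes) and a CUDA GPU; device_map="auto" additionally
# needs `accelerate` installed.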

# ----------------------------------------------------------------------
# 1. CLI args
# ----------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model",
    default="kieraisverybored/fein",  # default = Hub repo
    help="HF repo ID *or* path to a local model folder",
)
parser.add_argument("--load-8bit", action="store_true",
                    help="Load in 8-bit (else 4-bit)")
args = parser.parse_args()

MODEL_ID = args.model
SYSTEM_MSG = "You are a helpful assistant. You are the 'fein 14b' model by kieradev, a 14b LLM fine-tuned from Qwen3."

# ----------------------------------------------------------------------
# 2. Load tokenizer & model
# ----------------------------------------------------------------------
print(f"Loading model from: {MODEL_ID}")
dtype = torch.bfloat16  # or torch.float16 if your GPU prefers
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=not args.load_8bit,
    load_in_8bit=args.load_8bit,
    bnb_4bit_compute_dtype=dtype,
)
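# Exactly one of load_in_4bit / load_in_8bit is True, so the model always
# loads quantized: 4-bit by default, 8-bit when --load-8bit is passed.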

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=dtype,
    quantization_config=bnb_cfg,
)
model.eval()
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ----------------------------------------------------------------------
# 3. Prompt builder, chat loop
# ----------------------------------------------------------------------
T_START, T_END = "<|im_start|>", "<|im_end|>"

def build_prompt(history, user_msg):
    prompt = f"{T_START}system\n{SYSTEM_MSG}{T_END}\n"
    for u, a in history:
        prompt += f"{T_START}user\n{u}{T_END}\n"
        prompt += f"{T_START}assistant\n{a}{T_END}\n"
    prompt += f"{T_START}user\n{user_msg}{T_END}\n"
    prompt += f"{T_START}assistant\n"
    return prompt
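
# For a single exchange, build_prompt() renders Qwen's ChatML layout:
#   <|im_start|>system\n{SYSTEM_MSG}<|im_end|>\n
#   <|im_start|>user\n{user_msg}<|im_end|>\n
#   <|im_start|>assistant\n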

history = []
print("\nChat ready! Type 'exit' or Ctrl-C to quit.\n")
while True:
    try:
        user_in = input("User: ").strip()
    except (KeyboardInterrupt, EOFError):
        print("\nBye.")
        break
    if user_in.lower() in {"exit", "quit"}:
        break
    if not user_in:
        continue

    prompt = build_prompt(history, user_in)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():  # inference only; no gradient tracking needed
        gen_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

    full = tokenizer.decode(gen_ids[0], skip_special_tokens=False)
    # The prompt ends with "<|im_start|>assistant\n", so the last assistant
    # block in the decoded text is the newly generated reply.
    answer = full.split(f"{T_START}assistant\n")[-1].split(T_END)[0].strip()

    print(f"Assistant: {answer}\n")
    history.append((user_in, answer))
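
If the repo's tokenizer ships a chat template (Qwen-derived models usually do), the hand-rolled build_prompt() can be cross-checked against the renderer built into transformers. A minimal sketch, assuming the fein tokenizer config includes such a template; check_template.py is a hypothetical helper, not part of this commit:

    # check_template.py: hypothetical cross-check, not part of this commit
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("kieraisverybored/fein")
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    # tokenize=False returns the rendered prompt string instead of token IDs;
    # add_generation_prompt=True appends the trailing "<|im_start|>assistant\n".
    print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))

If the rendering matches what build_prompt() emits, the manual ChatML strings agree with the template the model was trained on.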