revana committed on
Commit 4162d87 · verified · 1 Parent(s): 5587a84

Upload infer.py

Files changed (1)
  1. infer.py +176 -0
infer.py ADDED
#!/usr/bin/env python3
"""fingpt: inference with a LoRA adapter.

Loads the base model from HuggingFace Hub, injects LoRA layers using the
metadata stored in the adapter checkpoint, then runs generation.

Usage
-----
# Interactive REPL
python infer.py --adapter weights_lora_coder_1b5/adapter_final.pt

# Single prompt
python infer.py --adapter weights_lora_coder_1b5/adapter_final.pt \
    --prompt "Fix this Python code: ..."

# One-liner (pipe-friendly)
echo "Fix: def f(n): return n * f(n)" | python infer.py \
    --adapter weights_lora_coder_1b5/adapter_final.pt
"""

import argparse
import sys
from pathlib import Path

import torch

# Make the local `fingpt` package importable even when the script is
# launched from another working directory.
_HERE = Path(__file__).resolve().parent
sys.path.insert(0, str(_HERE))

from fingpt.lora import inject_lora


# ── Model loading ─────────────────────────────────────────────────────────────

def load_model(adapter_path: str):
    """Load base model + inject LoRA + load adapter weights.

    All config is read from the adapter checkpoint metadata so you never
    need to pass model name / r / alpha manually.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    ckpt = torch.load(adapter_path, map_location="cpu", weights_only=False)
    meta = ckpt["meta"]
    state_dict = ckpt["state_dict"]

    model_name = meta["model_name"]
    lora_r = meta["lora_r"]
    lora_alpha = meta["lora_alpha"]
    lora_targets = meta["lora_target_modules"]
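
    # The adapter file is assumed to have been written by the training script
    # roughly like this (a sketch; only the keys read above are confirmed):
    #
    #   torch.save({
    #       "meta": {"model_name": ..., "lora_r": ..., "lora_alpha": ...,
    #                "lora_target_modules": [...]},
    #       "state_dict": {...the trained LoRA tensors only...},
    #   }, "adapter_final.pt")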

    print(f"[infer] base={model_name} r={lora_r} α={lora_alpha}")
    print(f"[infer] targets={lora_targets}")

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Base model
    cuda_ok = torch.cuda.is_available()
    try:
        import accelerate  # noqa: F401
        # device_map="auto" requires accelerate; it places the model on GPU
        load_kwargs = {"device_map": "auto"} if cuda_ok else {}
    except ImportError:
        load_kwargs = {}

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=torch.bfloat16 if cuda_ok else torch.float32,
        trust_remote_code=True,
        **load_kwargs,
    )
    if not load_kwargs:
        # No accelerate (or no GPU): place the whole model manually
        device = torch.device("cuda" if cuda_ok else "cpu")
        model = model.to(device)

    # Inject LoRA (dropout=0 at inference; no regularisation needed)
    model = inject_lora(model, target_modules=lora_targets,
                        r=lora_r, alpha=lora_alpha, dropout=0.0)
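
    # A sketch of what inject_lora is assumed to do (the real implementation
    # lives in fingpt/lora.py): wrap each target nn.Linear with weight W so
    # its forward pass becomes
    #
    #     h = x @ W.T + (alpha / r) * (x @ A.T) @ B.T
    #
    # where A is (r, in_features) and B is (out_features, r). A and B are the
    # only trained tensors, and they are exactly what the checkpoint restores.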

    # Load trained adapter weights. strict=False because state_dict holds only
    # the LoRA tensors, so every frozen base weight shows up in `missing`; we
    # only fail if a LoRA key itself is absent.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    lora_missing = [k for k in missing if "lora" in k]
    if lora_missing:
        raise ValueError(f"Missing LoRA keys: {lora_missing}")
    print(f"[infer] Loaded {len(state_dict)} adapter tensors from {adapter_path}")

    model.eval()
    return model, tokenizer


# ── Generation ────────────────────────────────────────────────────────────────

def generate(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.1,
) -> str:
    """Format prompt as ChatML and generate a response."""
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
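
    # With a ChatML-style template, `text` now looks roughly like this (the
    # exact special tokens depend on the tokenizer's chat template):
    #
    #   <|im_start|>user
    #   Fix this Python code: ...<|im_end|>
    #   <|im_start|>assistant
    #
    # add_generation_prompt=True appends the assistant header so generation
    # continues from there.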

    device = next(model.parameters()).device
    inputs = tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # temperature=0 means greedy decoding; HF warns if a temperature
            # is passed alongside do_sample=False, hence the 1.0 placeholder.
            do_sample=temperature > 0,
            temperature=temperature if temperature > 0 else 1.0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion; slice off the prompt tokens
    new_ids = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)


# ── CLI ───────────────────────────────────────────────────────────────────────

def main() -> None:
    parser = argparse.ArgumentParser(
        description="fingpt LoRA inference",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("--adapter", required=True,
                        help="Path to adapter .pt file")
    parser.add_argument("--prompt", default=None,
                        help="Single prompt string (omit for interactive REPL)")
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.1,
                        help="0 = greedy, >0 = sampling")
    args = parser.parse_args()

    model, tokenizer = load_model(args.adapter)

    if args.prompt:
        print(generate(model, tokenizer, args.prompt,
                       args.max_new_tokens, args.temperature))
        return

    # Check stdin (pipe mode)
    if not sys.stdin.isatty():
        prompt = sys.stdin.read().strip()
        if prompt:
            print(generate(model, tokenizer, prompt,
                           args.max_new_tokens, args.temperature))
        return

    # Interactive REPL
    print("[infer] Interactive mode; type 'quit' or Ctrl-D to exit.\n")
    while True:
        try:
            prompt = input(">>> ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if not prompt:
            continue
        if prompt.lower() in ("quit", "exit", "q"):
            break
        print()
        print(generate(model, tokenizer, prompt,
                       args.max_new_tokens, args.temperature))
        print()


if __name__ == "__main__":
    main()
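

# Programmatic use (a sketch; the prompt is illustrative):
#
#   from infer import load_model, generate
#   model, tokenizer = load_model("weights_lora_coder_1b5/adapter_final.pt")
#   print(generate(model, tokenizer, "Refactor this loop into a list comp."))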