DarshanScripts committed on
Commit
bd46c34
·
verified ·
1 Parent(s): f6fcf0c

Upload stratego/models/hf_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stratego/models/hf_model.py +33 -0
stratego/models/hf_model.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from .base import AgentLike
4
+ from ..utils.parsing import extract_legal_moves, extract_forbidden, slice_board_and_moves, MOVE_RE
5
+ from ..prompts import get_prompt_pack
6
+
7
class HFLocalAgent(AgentLike):
    """Agent that picks a move by prompting a locally loaded Hugging Face causal LM.

    The instance is callable: given an observation string it returns one move
    string, or "" when no legal moves can be extracted from the observation.
    """

    def __init__(self, model_id: str, prompt_pack: str = "base", **gen):
        """Load the tokenizer and model for ``model_id`` and store generation settings.

        Args:
            model_id: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
            prompt_pack: Name handed to ``get_prompt_pack``; the returned pack
                supplies ``system`` and ``guidance`` used to build the prompt.
            **gen: Extra keyword arguments forwarded to ``model.generate``.
                They override the defaults below when the keys collide.
        """
        self.model_name = f"hf:{model_id}"
        self.pack = get_prompt_pack(prompt_pack)
        self.tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        # bfloat16 when CUDA is available, float32 otherwise.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map="auto",
        )
        # Merge via dict unpacking so caller-supplied settings override the
        # defaults; dict(default=..., **gen) would raise TypeError on a
        # duplicate key such as max_new_tokens.
        self.gen = {
            "max_new_tokens": 32,
            "do_sample": True,
            "temperature": 0.1,
            "top_p": 0.9,
            **gen,
        }

    def __call__(self, observation: str) -> str:
        """Return a move for ``observation``.

        Moves reported by ``extract_forbidden`` are dropped from the candidate
        list unless that would leave no candidates at all. The model's
        completion is searched with ``MOVE_RE``; if the first match is a
        candidate it is returned, otherwise the first candidate is returned.

        Args:
            observation: Raw observation text handed to the parsing helpers.

        Returns:
            The chosen move string, or "" if no legal moves were extracted.
        """
        legal = extract_legal_moves(observation)
        if not legal:
            return ""
        forbidden = set(extract_forbidden(observation))
        # Fall back to the unfiltered list if every legal move is forbidden.
        candidates = [move for move in legal if move not in forbidden] or legal

        system_text = self.pack.system
        user_text = self.pack.guidance(slice_board_and_moves(observation))
        prompt = f"{system_text}\n\n{user_text}"

        inputs = self.tok(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            out = self.model.generate(**inputs, **self.gen)
        # Decode only the newly generated tokens. Slicing the decoded text by
        # len(prompt) is unreliable because decoding the prompt tokens is not
        # guaranteed to reproduce the prompt string character-for-character.
        completion = self.tok.decode(out[0][prompt_len:], skip_special_tokens=True)

        match = MOVE_RE.search(completion)
        if match and match.group(0) in candidates:
            return match.group(0)
        return candidates[0]