hwding commited on
Commit
5b25c5f
·
verified ·
1 Parent(s): 38f8bec

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +68 -0
handler.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
+ from peft import PeftModel
5
+
6
+
7
class EndpointHandler:
    """Inference Endpoints handler serving a 4-bit-quantized DeepSeek-Coder
    base model with a PEFT (LoRA) adapter loaded from *path*.

    Contract (Hugging Face custom-handler convention):
        handler = EndpointHandler(path)          # path = repo dir with adapter + tokenizer
        out = handler({"inputs": "...", "parameters": {...}})
        # -> {"generated_text": "..."}
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer, the quantized base model, and the adapter.

        Args:
            path: Local directory of the deployed repo; must contain the
                PEFT adapter weights and the tokenizer files.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # NF4 4-bit quantization with double quantization; bf16 compute
        # keeps generation quality while fitting the 6.7B model in memory.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

        base_model_id = "deepseek-ai/deepseek-coder-6.7b-instruct"

        # Tokenizer comes from the adapter repo (*path*); base weights are
        # pulled from the Hub and sharded automatically via device_map.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model = PeftModel.from_pretrained(self.model, path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a completion for the request payload.

        Args:
            data: ``{"inputs": <prompt str>, "parameters": {optional generation
                overrides: max_new_tokens, temperature, top_p, do_sample}}``.

        Returns:
            ``{"generated_text": <str>}`` containing only the newly generated
            text (the prompt is never echoed back).
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.95)
        do_sample = parameters.get("do_sample", True)

        # Wrap bare user prompts in the instruction template; callers that
        # already send a "### System:"-formatted prompt are passed through.
        if not inputs.startswith("### System:"):
            prompt = f"""### System:
You are an expert Minecraft Forge mod developer for version 1.21.11. Write clean, efficient, and well-structured Java code.

### User:
{inputs}

### Assistant:
"""
        else:
            prompt = inputs

        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_len = encoded["input_ids"].shape[1]

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        # temperature/top_p only apply when sampling; passing them with
        # do_sample=False makes transformers emit warnings and ignore them.
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        with torch.no_grad():
            outputs = self.model.generate(**encoded, **gen_kwargs)

        # BUG FIX: decode only the newly generated tokens. Decoding the full
        # sequence echoed the prompt back to the client whenever the output
        # did not happen to contain the "### Assistant:" marker.
        generated_text = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()

        # Defensive fallback: if the model itself emitted the scaffold
        # marker, keep only the text after the last occurrence.
        if "### Assistant:" in generated_text:
            generated_text = generated_text.split("### Assistant:")[-1].strip()

        return {"generated_text": generated_text}