GitHub Actions committed on
Commit
3b4cee3
·
1 Parent(s): a7086fd

Sync from GitHub Actions

Browse files
Files changed (1) hide show
  1. server.py +56 -0
server.py CHANGED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from codeInsight.utils.config import load_config
4
+ import litserve as ls
5
+
6
class LLMApi(ls.LitAPI):
    """LitServe API that serves a fine-tuned causal language model.

    Loads the tokenizer/model named in the YAML config at setup time and
    exposes `generate` to produce a completion for a prompt wrapped in the
    project's chat-template tokens.
    """

    def setup(self, device, config_path="config/model.yaml"):
        """Load configuration, tokenizer, and model, and prepare for inference.

        Args:
            device: Target device string supplied by LitServe (e.g. "cuda:0",
                "cpu").
            config_path: Path to the YAML config file; must contain
                ``dataset`` (prompt tokens) and ``paths.final_model_repo``
                (HF model id or local path) sections.
        """
        self.config = load_config(config_path)
        self.dataset_config = self.config['dataset']
        model_name = self.config['paths']['final_model_repo']

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        if device != "cpu":
            # nn.Module.to mutates the module in place; no reassignment needed.
            self.model.to(device)
        self.model.eval()

    def _format_prompt(self, prompt: str) -> str:
        """Wrap *prompt* in the system/user/assistant template tokens."""
        # NOTE: renamed from the misspelled `_formet_prompt`; private helper,
        # only called from `generate` below.
        return (
            f"{self.dataset_config['SYSTEM_PROMPT']}"
            f"{self.dataset_config['USER_TOKEN']}{prompt}"
            f"{self.dataset_config['END_TOKEN']}\n\n"
            f"{self.dataset_config['ASSISTANT_TOKEN']}"
        )

    def generate(self, prompt: str, max_length: int = 512, temperature: float = 0.2, top_p: float = 0.80) -> dict:
        """Generate a completion for *prompt*.

        Args:
            prompt: Raw user prompt; it is wrapped in the chat template
                before tokenization.
            max_length: Maximum number of NEW tokens to sample (passed as
                ``max_new_tokens``).
            temperature: Sampling temperature.
            top_p: Nucleus-sampling cutoff.

        Returns:
            ``{"response": text}`` on success, or ``{"error": message}`` if
            anything raised during generation.
        """
        try:
            input_text = self._format_prompt(prompt)
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
            ).to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    # Stop at the template's end token rather than the
                    # tokenizer's default EOS.
                    eos_token_id=self.tokenizer.convert_tokens_to_ids(self.dataset_config['END_TOKEN']),
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Strip the echoed prompt: keep only what follows the assistant
            # token, and drop anything after the end token if present.
            assistant_token = self.dataset_config['ASSISTANT_TOKEN']
            end_token = self.dataset_config['END_TOKEN']
            if assistant_token in generated_text:
                generated_code = generated_text.split(assistant_token)[1].strip()
                if end_token in generated_code:
                    generated_code = generated_code.split(end_token)[0].strip()
            else:
                generated_code = generated_text
            return {"response": generated_code}

        except Exception as e:
            # Serving boundary: report the failure in the payload instead of
            # letting the exception crash the request handler.
            return {"error": str(e)}
52
+
53
+
54
if __name__ == "__main__":
    # Launch the LitServe server around the LLM API; accelerator is
    # auto-detected (GPU if available, else CPU).
    api = LLMApi()
    server = ls.LitServer(api, accelerator="auto")
    server.run()