subhankarfynd commited on
Commit
ffcd623
·
1 Parent(s): 968af1c

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +69 -0
handler.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import bitsandbytes as bnb
3
+ import torch
4
+ import transformers
5
+ from datasets import load_dataset
6
+ from typing import Dict, List, Any
7
+ from peft import (
8
+ LoraConfig,
9
+ PeftConfig,
10
+ PeftModel,
11
+ get_peft_model,
12
+ prepare_model_for_kbit_training,
13
+ )
14
+ from transformers import (
15
+ AutoConfig,
16
+ LlamaTokenizer,
17
+ LlamaForCausalLM,
18
+ #AutoModelForCausalLM,
19
+ #AutoTokenizer,
20
+ BitsAndBytesConfig,
21
+ )
22
+ import json
23
+
24
+ bnb_config = BitsAndBytesConfig(
25
+ load_in_4bit=True,
26
+ bnb_4bit_use_double_quant=True,
27
+ bnb_4bit_quant_type="nf4",
28
+ bnb_4bit_compute_dtype=torch.bfloat16,
29
+ )
30
+
31
+
32
+ from huggingface_hub import login
33
+ access_token_read = "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"
34
+ login(token = access_token_read)
35
+
36
+
37
+ class EndpointHandler:
38
+ def __init__(self, path=''):
39
+ PEFT_MODEL = path
40
+ config = PeftConfig.from_pretrained(PEFT_MODEL)
41
+ self.model = LlamaForCausalLM.from_pretrained(
42
+ config.base_model_name_or_path,
43
+ return_dict=True,
44
+ quantization_config=bnb_config,
45
+ device_map="auto",
46
+ trust_remote_code=True,
47
+ )
48
+ self.tokenizer = LlamaTokenizer.from_pretrained(config.base_model_name_or_path)
49
+ self.tokenizer.pad_token_id = (0)
50
+ self.tokenizer.padding_side = "left"
51
+ self.model = PeftModel.from_pretrained(self.model, PEFT_MODEL)
52
+ self.generation_config = self.model.generation_config
53
+ self.generation_config.max_new_tokens = 100
54
+ self.generation_config.pad_token_id = self.tokenizer.eos_token_id
55
+ self.generation_config.eos_token_id = self.tokenizer.eos_token_id
56
+
57
+
58
+ def __call__(self, data: Dict[str, Any]):
59
+ prompt = data.pop("inputs", data)
60
+ DEVICE = "cuda:0"
61
+ input_message = f"""[INST]You are an assistant that detects the intent and entity of user's message. Possible entity stores are JioMart, JioFiber, JioCinema and Tira Beauty. Detect the intent and entity of the following user's message[/INST]\nUser: {prompt}\nAssistant: """.strip()
62
+ encoding = self.tokenizer(input_message, return_tensors="pt").to(DEVICE)
63
+ with torch.inference_mode():
64
+ outputs = self.model.generate(
65
+ input_ids=encoding.input_ids,
66
+ attention_mask=encoding.attention_mask,
67
+ generation_config=self.generation_config
68
+ )
69
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)[len(input_message):]