vignesh0007 commited on
Commit
4c81d8f
·
verified ·
1 Parent(s): 6960990

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +48 -46
handler.py CHANGED
@@ -1,46 +1,48 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM
2
- from peft import get_peft_model, LoraConfig
3
- from safetensors.torch import load_file
4
- from huggingface_hub import hf_hub_download
5
- import torch
6
-
7
- class EndpointHandler:
8
- def __init__(self, path=""):
9
- self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
10
- base_model = AutoModelForCausalLM.from_pretrained(
11
- "meta-llama/Llama-2-7b-hf",
12
- torch_dtype=torch.float16,
13
- device_map="auto"
14
- )
15
-
16
- lora_config = LoraConfig(
17
- r=8,
18
- lora_alpha=32,
19
- target_modules=["q_proj"],
20
- lora_dropout=0.05,
21
- bias="none",
22
- task_type="CAUSAL_LM"
23
- )
24
-
25
- self.model = get_peft_model(base_model, lora_config)
26
- adapter_path = hf_hub_download(
27
- repo_id="vignesh0007/Anime-Gen-Llama-2-7B",
28
- filename="adapter_model.safetensors",
29
- repo_type="model"
30
- )
31
- lora_state = load_file(adapter_path)
32
- self.model.load_state_dict(lora_state, strict=False)
33
- self.model.eval()
34
-
35
- def __call__(self, data):
36
- inputs = data.get("inputs", "")
37
- tokens = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
38
- with torch.no_grad():
39
- outputs = self.model.generate(
40
- **tokens,
41
- max_new_tokens=256,
42
- temperature=0.8,
43
- top_p=0.95,
44
- do_sample=True
45
- )
46
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
1
import os

import torch
from huggingface_hub import hf_hub_download
from peft import get_peft_model, LoraConfig
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM

# Llama-2 is a gated repo, so downloads need an authenticated token.
# (Original code read this env var but `os` was never imported -> NameError.)
token = os.getenv("HUGGINGFACE_HUB_TOKEN")


class EndpointHandler:
    """Inference-endpoint handler for the Anime-Gen Llama-2-7B LoRA model.

    Loads the Llama-2-7B base model in fp16, attaches a LoRA adapter
    downloaded from the Hub, and generates text for incoming requests.
    """

    def __init__(self, path=""):
        """Load tokenizer, base model, and LoRA adapter weights.

        Args:
            path: Unused; accepted for endpoint-handler API compatibility.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-2-7b-hf", token=token
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf",
            torch_dtype=torch.float16,
            device_map="auto",
            token=token,
        )

        # LoRA config must match how the adapter was trained
        # (rank-8 adapters on the query projection only).
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["q_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        self.model = get_peft_model(base_model, lora_config)
        adapter_path = hf_hub_download(
            repo_id="vignesh0007/Anime-Gen-Llama-2-7B",
            filename="adapter_model.safetensors",
            repo_type="model",
            token=token,
        )
        lora_state = load_file(adapter_path)
        # strict=False: the adapter file only contains LoRA deltas,
        # not the full model state dict.
        self.model.load_state_dict(lora_state, strict=False)
        self.model.eval()

    def __call__(self, data):
        """Generate a completion for the request payload.

        Args:
            data: Request dict; the prompt is read from data["inputs"]
                  (empty string if absent).

        Returns:
            The decoded generation (prompt + continuation) as a string.
        """
        inputs = data.get("inputs", "")
        tokens = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **tokens,
                max_new_tokens=256,
                temperature=0.8,
                top_p=0.95,
                do_sample=True,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)