sainikhiljuluri committed (verified)
Commit 5b24ffc · 1 Parent(s): 05f6a4c

Add handler.py for Inference Endpoints

Files changed (1):
1. handler.py +17 -58
handler.py CHANGED
@@ -1,27 +1,14 @@
-"""
-Custom Handler for DeepSeek-R1-Cybersecurity-8B-Merged
-HuggingFace Inference Endpoints
-"""
-
 from typing import Dict, Any
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

-
 class EndpointHandler:
     def __init__(self, path: str = ""):
-        """Initialize the model and tokenizer."""
-        # Load tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            path,
-            trust_remote_code=True
-        )
-
-        # Set pad token if not set
+        """Initialize model and tokenizer."""
+        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token

-        # Load model
         self.model = AutoModelForCausalLM.from_pretrained(
             path,
             torch_dtype=torch.bfloat16,
@@ -29,43 +16,18 @@ class EndpointHandler:
             trust_remote_code=True
         )
         self.model.eval()
-
-        # Get device
         self.device = next(self.model.parameters()).device
-        print(f"Model loaded on device: {self.device}")
+        print(f"Model loaded on {self.device}")

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        """Process inference request."""
-        # Extract inputs
+        """Handle inference request."""
         inputs = data.get("inputs", data.get("input", ""))
-        parameters = data.get("parameters", {})
-
-        # Handle both string and list inputs
-        if isinstance(inputs, str):
-            prompts = [inputs]
-        else:
-            prompts = inputs
-
-        # Default generation parameters
-        generation_config = {
-            "max_new_tokens": parameters.get("max_new_tokens", 256),
-            "temperature": parameters.get("temperature", 0.7),
-            "top_p": parameters.get("top_p", 0.9),
-            "top_k": parameters.get("top_k", 50),
-            "do_sample": parameters.get("do_sample", True),
-            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
-            "pad_token_id": self.tokenizer.pad_token_id,
-            "eos_token_id": self.tokenizer.eos_token_id,
-        }
-
-        # Remove None values
-        generation_config = {k: v for k, v in generation_config.items() if v is not None}
+        params = data.get("parameters", {})

         # Tokenize
         encoded = self.tokenizer(
-            prompts,
+            inputs,
             return_tensors="pt",
-            padding=True,
             truncation=True,
             max_length=2048
         ).to(self.device)
@@ -74,20 +36,17 @@ class EndpointHandler:
         with torch.no_grad():
             outputs = self.model.generate(
                 **encoded,
-                **generation_config
+                max_new_tokens=params.get("max_new_tokens", 256),
+                temperature=params.get("temperature", 0.7),
+                top_p=params.get("top_p", 0.9),
+                do_sample=params.get("do_sample", True),
+                repetition_penalty=params.get("repetition_penalty", 1.1),
+                pad_token_id=self.tokenizer.pad_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
             )

-        # Decode
-        generated_texts = []
-        for i, output in enumerate(outputs):
-            # Remove the input tokens from the output
-            input_length = encoded["input_ids"][i].shape[0]
-            generated_tokens = output[input_length:]
-            text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
-            generated_texts.append(text)
+        # Decode (remove input tokens)
+        generated = outputs[0][encoded["input_ids"].shape[1]:]
+        text = self.tokenizer.decode(generated, skip_special_tokens=True)

-        # Return single string if single input, else list
-        if isinstance(inputs, str):
-            return {"generated_text": generated_texts[0]}
-        else:
-            return {"generated_text": generated_texts}
+        return {"generated_text": text}