echeyde committed on
Commit
9784a5c
·
verified ·
1 Parent(s): 4b63e65

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +74 -37
handler.py CHANGED
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


class EndpointHandler:
    """Inference-endpoint handler that lazily loads a causal LM.

    The model and tokenizer are loaded from the repository root (".") on the
    first request rather than at construction time.
    """

    def __init__(self):
        # Deferred initialization — populated by load_model().
        self.tokenizer = None
        self.model = None

    def load_model(self):
        """Load model and tokenizer from the repository root."""
        self.tokenizer = AutoTokenizer.from_pretrained(".")
        self.model = AutoModelForCausalLM.from_pretrained(
            ".",
            torch_dtype=torch.float16,
            device_map="auto",
            use_safetensors=True,
        )

    def __call__(self, data):
        """Inference request handler.

        Args:
            data: dict with "inputs" (prompt text) and optional "parameters"
                  ("max_length", "temperature", "top_p").

        Returns:
            {"generated_text": <decoded generation>}
        """
        # BUG FIX: load_model() was never called anywhere, so self.model and
        # self.tokenizer were still None on the first request. Load lazily here.
        if self.model is None:
            self.load_model()

        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Set default parameters if not provided
        max_length = parameters.get("max_length", 100)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.9)

        # Tokenize inputs and move them to the model's device
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to(self.model.device)

        # Generate without tracking gradients
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_length=max_length,
                # BUG FIX: without do_sample=True, temperature/top_p are
                # silently ignored and decoding is greedy.
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode and return response
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": generated_text}
from typing import Any, Dict

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


class EndpointHandler:
    """Inference-endpoint handler for a causal language model.

    Loads the model and tokenizer once at startup, moves the model to GPU when
    available, and serves text-generation requests through __call__.
    """

    def __init__(self, path: str = ""):
        # Load tokenizer and model from `path` (repository root by default).
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        # BUG FIX: causal-LM tokenizers such as GPT-2 ship without a pad
        # token; tokenizing with padding=True below would raise without one.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Inference only — disable dropout and other training-mode behavior.
        self.model.eval()

    # BUG FIX: the original annotation List[Dict[str, Any]] did not match the
    # actual return value, which is a single dict.
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one text-generation request.

        Args:
            data: JSON input with structure:
                {
                    "inputs": "your text prompt here",
                    "parameters": {
                        "max_new_tokens": 50,
                        "temperature": 0.7,
                        "top_p": 0.9,
                        "do_sample": true
                    }
                }
                A bare string payload is also accepted (treated as "inputs").

        Returns:
            {"generated_text": first generation,
             "all_generations": all generations (num_return_sequences > 1)}
        """
        # ROBUSTNESS: wire in the previously-unused preprocess hook so a bare
        # string payload no longer crashes on data.pop (str has no .pop).
        data = self.preprocess(data)

        # Get input text and parameters
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Default generation parameters, overridable per request.
        generation_config = {
            "max_new_tokens": parameters.get("max_new_tokens", 50),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "pad_token_id": self.tokenizer.eos_token_id,
            "num_return_sequences": parameters.get("num_return_sequences", 1),
        }

        # Tokenize; truncate long prompts to a 512-token budget. (Renamed from
        # `inputs` to avoid shadowing the prompt text.)
        encoded = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)

        # Generate text without tracking gradients.
        with torch.no_grad():
            generated_ids = self.model.generate(
                encoded.input_ids,
                attention_mask=encoded.attention_mask,
                **generation_config,
            )

        # Decode and return generated text
        generated_texts = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )

        # ROBUSTNESS: route the result through the previously-unused
        # postprocess hook (identity by default, so behavior is unchanged).
        return self.postprocess({
            "generated_text": generated_texts[0],  # first generation
            "all_generations": generated_texts,    # all, if num_return_sequences > 1
        })

    def preprocess(self, data):
        """Normalize raw payloads: wrap a bare string as {"inputs": ...}."""
        if isinstance(data, str):
            return {"inputs": data}
        return data

    def postprocess(self, data):
        """Hook for output clean-up; identity by default."""
        return data