prashanthbsp committed on
Commit
4225454
·
1 Parent(s): 974d5cf

add custom handler

Browse files
Files changed (1) hide show
  1. handler.py +28 -39
handler.py CHANGED
@@ -1,36 +1,25 @@
1
  from typing import Dict, List, Any
2
- from unsloth import FastLanguageModel
3
- class EndpointHandler():
4
- def __init__(self, path="prashanthbsp/DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit-reasoning-cpg-entity-v1"):
5
- # Preload all the elements you are going to need at inference.
6
- # pseudo:
7
- # self.model= load_model(path)
8
- max_seq_length = 2048
9
- dtype = None
10
- load_in_4bit = True
11
- model, tokenizer = FastLanguageModel.from_pretrained(
12
- model_name = path,
13
- max_seq_length = max_seq_length,
14
- dtype = dtype,
15
- load_in_4bit = load_in_4bit,
16
- )
17
- self.model = model
18
- self.tokenizer = tokenizer
19
 
20
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
 
 
 
 
 
21
  """
22
- data args:
23
- inputs (:obj: `str` | `PIL.Image` | `np.array`)
24
- kwargs
25
- Return:
26
- A :obj:`list` | `dict`: will be serialized and returned
27
  """
28
-
29
- # pseudo
30
- # self.model(input)
31
  inputs = data.pop("inputs", data)
32
  context = inputs.pop("context", inputs)
33
- prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
 
 
34
  Write a response that appropriately completes the request.
35
  Before answering, think carefully about the task to ensure a logical and accurate response.
36
 
@@ -65,16 +54,16 @@ class EndpointHandler():
65
  }}
66
 
67
  ### Social Media Post:
68
- {0}
69
  ### Response:
70
- <think>{1}"""
71
- FastLanguageModel.for_inference(model)
72
- inputs = tokenizer([prompt_style.format(context, "")], return_tensors="pt").to("cuda")
73
- outputs = model.generate(
74
- input_ids=inputs.input_ids,
75
- attention_mask=inputs.attention_mask,
76
- max_new_tokens=1200,
77
- use_cache=True,
78
- )
79
- response = tokenizer.batch_decode(outputs)
80
- return response[0].split("### Response:")[1]
 
1
  from typing import Dict, List, Any
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ class EndpointHandler:
5
+ def __init__(self, path="prashanthbsp/reasoning-cpg-entity-v1"):
6
+ # Standard HF model loading - compatible with TGI
7
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
8
+ # Model is loaded by the TGI server, not by the handler
9
+
10
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
11
  """
12
+ data args:
13
+ inputs: text or dict containing text
14
+ Return:
15
+ A dict with the model's response
 
16
  """
17
+ # Extract inputs
 
 
18
  inputs = data.pop("inputs", data)
19
  context = inputs.pop("context", inputs)
20
+
21
+ # Format prompt according to your requirements
22
+ prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context.
23
  Write a response that appropriately completes the request.
24
  Before answering, think carefully about the task to ensure a logical and accurate response.
25
 
 
54
  }}
55
 
56
  ### Social Media Post:
57
+ {context}
58
  ### Response:
59
+ <think>"""
60
+
61
+ # For TGI, we return a dict with the prompt and generation params
62
+ return {
63
+ "inputs": prompt,
64
+ "parameters": {
65
+ "max_new_tokens": 1200,
66
+ "do_sample": False,
67
+ "return_full_text": False # Only return the generated text, not the prompt
68
+ }
69
+ }