askcatalystai committed
Commit dab0caa · verified · 1 Parent(s): d07c2dc

Upload handler.py with huggingface_hub

Files changed (1)
  handler.py  +55 -0
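Per the commit message, the file was pushed with the huggingface_hub client. A minimal sketch of such an upload, assuming a model repo; the repo id below is a placeholder, not taken from this page:

from huggingface_hub import HfApi

api = HfApi()

# Upload a local handler.py to the root of a model repo.
# "askcatalystai/<model-name>" is a placeholder repo id.
api.upload_file(
    path_or_fileobj="handler.py",
    path_in_repo="handler.py",
    repo_id="askcatalystai/<model-name>",
    repo_type="model",
    commit_message="Upload handler.py with huggingface_hub",
)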
handler.py ADDED
@@ -0,0 +1,55 @@
+ from typing import Dict, Any
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             path,
+             torch_dtype=torch.float16,
+             device_map="auto"
+         )
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         # Get inputs
+         inputs = data.get("inputs", "")
+         parameters = data.get("parameters", {})
+
+         # Extract parameters with defaults
+         max_new_tokens = parameters.get("max_new_tokens", 200)
+         temperature = parameters.get("temperature", 0.7)
+         top_p = parameters.get("top_p", 0.9)
+
+         # Format prompt if instruction/input provided separately
+         if isinstance(inputs, dict):
+             instruction = inputs.get("instruction", "")
+             product_details = inputs.get("product_details", "")
+             prompt = f"***Instruction: {instruction}\n***Input: {product_details}\n***Response:"
+         else:
+             prompt = inputs
+
+         # Tokenize
+         input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+         # Generate
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **input_ids,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+         # Decode
+         full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         # Extract response part
+         if "***Response:" in full_response:
+             response = full_response.split("***Response:")[1].strip()
+         else:
+             response = full_response
+
+         return {"generated_text": response}
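The handler accepts either a plain string or a dict under "inputs". For local testing it can be exercised directly with the payload shape it expects; a minimal sketch, assuming a downloaded copy of the repo (the "./model" path and product text below are placeholders, not part of this commit):

# Minimal local smoke test for the handler above.
# "./model" is a placeholder path to a local copy of the model repo.
from handler import EndpointHandler

handler = EndpointHandler(path="./model")

payload = {
    "inputs": {
        "instruction": "Write a product description.",
        "product_details": "Stainless steel water bottle, 750 ml, vacuum insulated.",
    },
    "parameters": {"max_new_tokens": 120, "temperature": 0.7, "top_p": 0.9},
}

result = handler(payload)
print(result["generated_text"])

On a deployed Inference Endpoint, the same JSON shape would be sent as the request body, and the handler responds with a {"generated_text": ...} object.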