Mir-2002 committed on
Commit
f2f11b6
·
verified ·
1 Parent(s): 832a366

Delete handler.py

Browse files
Files changed (1) hide show
  1. handler.py +0 -49
handler.py DELETED
@@ -1,49 +0,0 @@
1
- from typing import Any, Dict, List
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- import torch
4
-
5
# Token limit applied to each tokenized request text; longer texts are truncated.
MAX_INPUT_LENGTH = 256
# Value passed to ``generate`` as ``max_length`` for the produced sequences.
MAX_OUTPUT_LENGTH = 128


class EndpointHandler:
    """Callable request handler wrapping a seq2seq model and its tokenizer.

    The tokenizer and model are loaded once in ``__init__``. Each call
    tokenizes the request's ``"inputs"``, runs beam-search generation and
    returns the decoded texts.
    """

    def __init__(self, model_dir: str = "", **kwargs: Any) -> None:
        """Load the tokenizer and model and move the model to the device.

        Args:
            model_dir: Location handed to ``from_pretrained`` for both the
                tokenizer and the model.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()
        # Use the GPU when one is visible, otherwise stay on the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for the ``"inputs"`` entry of *data*.

        Args:
            data: Request payload. ``data["inputs"]`` is either a single
                string or a batch of strings.

        Returns:
            One ``{"generated_text": ...}`` dict per generated sequence. If
            generation or decoding fails, the failure is logged and a single
            dict whose text is ``"Error: <message>"`` is returned instead.

        Raises:
            ValueError: If ``"inputs"`` is missing or empty.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("No 'inputs' found in the request data.")

        # A lone string is treated as a batch of one.
        if isinstance(inputs, str):
            inputs = [inputs]

        tokenized_inputs = self.tokenizer(
            inputs,
            max_length=MAX_INPUT_LENGTH,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        try:
            # Gradients are never needed when serving a request.
            with torch.no_grad():
                outputs = self.model.generate(
                    tokenized_inputs["input_ids"],
                    attention_mask=tokenized_inputs["attention_mask"],
                    max_length=MAX_OUTPUT_LENGTH,
                    num_beams=4,  # beam search; sampling is disabled below
                    no_repeat_ngram_size=3,
                    early_stopping=True,
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            decoded_outputs = self.tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )
        except Exception as e:
            # Best-effort contract: the caller still receives a well-formed
            # payload, but the traceback is recorded so the failure is not
            # silently disguised as model output.
            import logging

            logging.getLogger(__name__).exception("Text generation failed")
            return [{"generated_text": f"Error: {e}"}]

        return [{"generated_text": text} for text in decoded_outputs]