Mir-2002 committed (verified)
Commit becbf5f · 1 parent: 7fa5830

Delete handler.py

Files changed (1): handler.py +0 -89
handler.py DELETED
@@ -1,89 +0,0 @@
-from typing import Any, Dict, List
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-import torch
-import os
-
-MAX_INPUT_LENGTH = 256
-MAX_OUTPUT_LENGTH = 128
-
-class EndpointHandler:
-    def __init__(self, model_dir: str = "", num_threads: int | None = None, generation_config: Dict[str, Any] | None = None, **kwargs: Any) -> None:
-        # Set environment hints for CPU efficiency
-        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
-
-        # Configure torch threading for CPU
-        if num_threads:
-            try:
-                torch.set_num_threads(num_threads)
-                torch.set_num_interop_threads(max(1, num_threads // 2))
-            except Exception:
-                pass
-            os.environ.setdefault("OMP_NUM_THREADS", str(num_threads))
-            os.environ.setdefault("MKL_NUM_THREADS", str(num_threads))
-
-        self.device = "cpu"  # Force CPU usage
-
-        # Load tokenizer & model with CPU-friendly settings
-        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, low_cpu_mem_usage=True)
-        self.model.eval()
-        self.model.to(self.device)
-
-        # Optional bfloat16 cast on CPU (beneficial on Sapphire Rapids/oneDNN)
-        self._use_bf16 = False
-        if os.getenv("ENABLE_BF16", "1") == "1":
-            try:
-                self.model = self.model.to(dtype=torch.bfloat16)
-                self._use_bf16 = True
-            except Exception:
-                self._use_bf16 = False
-
-        # Determine a safe pad token id
-        pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
-
-        # Default fast generation config (greedy) overridable by caller
-        default_gen = {
-            "max_length": MAX_OUTPUT_LENGTH,
-            "num_beams": 1,  # Greedy for CPU speed
-            "do_sample": False,
-            "no_repeat_ngram_size": 3,
-            "early_stopping": True,
-            "use_cache": True,
-            "pad_token_id": pad_id,
-        }
-        if generation_config:
-            default_gen.update(generation_config)
-        self.generation_args = default_gen
-
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-        inputs = data.get("inputs")
-        if not inputs:
-            raise ValueError("No 'inputs' found in the request data.")
-
-        if isinstance(inputs, str):
-            inputs = [inputs]
-
-        # Allow per-request overrides under 'parameters'
-        per_request_params = data.get("parameters") or {}
-        gen_args = {**self.generation_args, **per_request_params}
-
-        tokenized_inputs = self.tokenizer(
-            inputs,
-            max_length=MAX_INPUT_LENGTH,
-            padding=True,
-            truncation=True,
-            return_tensors="pt"
-        ).to(self.device)
-
-        try:
-            with torch.inference_mode():
-                outputs = self.model.generate(
-                    tokenized_inputs["input_ids"],
-                    attention_mask=tokenized_inputs["attention_mask"],
-                    **gen_args
-                )
-            decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
-            results = [{"generated_text": text} for text in decoded_outputs]
-            return results
-        except Exception as e:
-            return [{"generated_text": f"Error: {str(e)}"}]
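
The deleted file follows the Hugging Face Inference Endpoints custom-handler protocol: a class named EndpointHandler in handler.py whose __call__ receives a dict with a required "inputs" field and optional per-request "parameters". For reference, a minimal local smoke test for this handler would look like the sketch below; the ./model path, thread count, and sample input are illustrative assumptions, not values from this repository.

    # Minimal local smoke test for the deleted handler (a sketch, assuming a
    # seq2seq checkpoint is available at ./model -- any directory loadable by
    # AutoModelForSeq2SeqLM works; run from the directory containing handler.py).
    from handler import EndpointHandler

    handler = EndpointHandler(model_dir="./model", num_threads=4)

    # Payload shape matches __call__: required "inputs" plus optional
    # "parameters" that override the generation defaults set in __init__.
    payload = {
        "inputs": "def add(a, b):\n    return a + b",
        "parameters": {"max_length": 64},
    }

    # Returns a list of {"generated_text": ...} dicts, one per input.
    for result in handler(payload):
        print(result["generated_text"])

On a deployed endpoint the platform instantiates EndpointHandler with the model directory and invokes it once per request, so the same payload would be sent as the JSON request body.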