algorythmtechnologies committed
Commit 199308a · verified · 1 Parent(s): 96f0f39

Create handler.py

Files changed (1)
  1. handler.py +140 -0
handler.py ADDED
@@ -0,0 +1,140 @@
+ from typing import Dict, List, Any, Optional
+
+ import torch
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TextIteratorStreamer,
+ )
+
+
+ class EndpointHandler:
+     """
+     Custom Inference Endpoints handler for algorythmtechnologies/Warren-8B-Uncensored-2000.
+
+     Expected JSON payload:
+     {
+         "inputs": "user prompt or message",
+         "max_new_tokens": 256,        # optional
+         "temperature": 0.7,           # optional
+         "top_p": 0.9,                 # optional
+         "top_k": 50,                  # optional
+         "repetition_penalty": 1.05,   # optional
+         "stop_sequences": ["</s>"]    # optional
+     }
+
+     Returns:
+     [
+         {
+             "generated_text": "...",
+             "finish_reason": "length|stop|error"
+         }
+     ]
+     """
+
+     def __init__(self, path: str = ""):
+         # Choose device
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Load tokenizer and model from the repository path
+         self.tokenizer = AutoTokenizer.from_pretrained(path or ".")
+         # Make sure there is a pad_token for generation
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         self.model = AutoModelForCausalLM.from_pretrained(
+             path or ".",
+             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+             device_map="auto" if self.device == "cuda" else None,
+         )
+
+         # Set model to eval mode
+         self.model.eval()
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         data args:
+             inputs (str): user text prompt
+             max_new_tokens (int, optional)
+             temperature (float, optional)
+             top_p (float, optional)
+             top_k (int, optional)
+             repetition_penalty (float, optional)
+             stop_sequences (List[str], optional)
+
+         Return:
+             A list with one dict:
+             [
+                 {
+                     "generated_text": str,
+                     "finish_reason": str
+                 }
+             ]
+         """
+         # Extract inputs
+         prompt: Optional[str] = data.get("inputs")
+         if prompt is None:
+             return [{"error": "Missing 'inputs' field in payload."}]
+
+         max_new_tokens: int = int(data.get("max_new_tokens", 256))
+         temperature: float = float(data.get("temperature", 0.7))
+         top_p: float = float(data.get("top_p", 0.9))
+         top_k: int = int(data.get("top_k", 50))
+         repetition_penalty: float = float(data.get("repetition_penalty", 1.05))
+         stop_sequences = data.get("stop_sequences", None)
+
+         # Tokenize
+         inputs = self.tokenizer(
+             prompt,
+             return_tensors="pt",
+             padding=False,
+             truncation=True,
+         ).to(self.device)
+
+         # Configure basic generation kwargs
+         gen_kwargs = dict(
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             repetition_penalty=repetition_penalty,
+             pad_token_id=self.tokenizer.pad_token_id,
+             eos_token_id=self.tokenizer.eos_token_id,
+         )
+
+         # Run generation
+         with torch.no_grad():
+             output_ids = self.model.generate(
+                 **inputs,
+                 **gen_kwargs,
+             )
+
+         # Decode full text and strip the original prompt
+         full_text = self.tokenizer.decode(
+             output_ids[0],
+             skip_special_tokens=True,
+         )
+
+         # Try to remove the prompt from the beginning for cleaner output
+         if full_text.startswith(prompt):
+             generated_text = full_text[len(prompt):].lstrip()
+         else:
+             generated_text = full_text
+
+         # Apply stop sequences post-hoc if provided
+         finish_reason = "length"
+         if stop_sequences:
+             for stop in stop_sequences:
+                 idx = generated_text.find(stop)
+                 if idx != -1:
+                     generated_text = generated_text[:idx]
+                     finish_reason = "stop"
+                     break
+
+         return [
+             {
+                 "generated_text": generated_text,
+                 "finish_reason": finish_reason,
+             }
+         ]
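
Usage note: a minimal sketch of how a client could call this handler once the repository is deployed as an Inference Endpoint. The endpoint URL and token below are placeholders, not values from this repo; also note that this handler reads generation parameters from the top level of the JSON payload rather than from a nested "parameters" object.

import requests

# Placeholder endpoint URL and token -- substitute your own deployment values.
API_URL = "https://your-endpoint.endpoints.huggingface.cloud"
HEADERS = {"Authorization": "Bearer hf_xxx", "Content-Type": "application/json"}

payload = {
    "inputs": "Explain what a custom inference handler does.",
    "max_new_tokens": 128,       # optional; handler default is 256
    "temperature": 0.7,          # optional
    "stop_sequences": ["</s>"],  # optional; trimmed post-hoc by the handler
}

resp = requests.post(API_URL, headers=HEADERS, json=payload)
out = resp.json()
print(out[0]["generated_text"], "| finish_reason:", out[0]["finish_reason"])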
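
For local testing, a quick smoke test is possible without deploying, assuming handler.py and the model weights sit in the current directory and torch/transformers are installed:

from handler import EndpointHandler

handler = EndpointHandler(path=".")
result = handler({"inputs": "Hello, Warren.", "max_new_tokens": 32})
print(result)
# -> [{'generated_text': '...', 'finish_reason': 'length'}]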