analist committed
Commit 8bc1b5c · verified · 1 Parent(s): 275007a

Create handler.py

Files changed (1)
  1. handler.py +198 -0
handler.py ADDED
@@ -0,0 +1,198 @@
+ import torch
+ import base64
+ from io import BytesIO
+ from typing import Dict, List, Any
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         """
+         Initialize the model handler
+         Args:
+             path: Path to the model weights (provided by HF Endpoints)
+         """
+         # 4-bit quantization configuration
+         self.bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_compute_dtype=torch.float16,
+         )
+
+         # Load the processor and the model
+         self.processor = AutoProcessor.from_pretrained(
+             path,
+             trust_remote_code=True
+         )
+
+         self.model = AutoModelForVision2Seq.from_pretrained(
+             path,
+             quantization_config=self.bnb_config,
+             device_map="auto",
+             trust_remote_code=True,
+             torch_dtype=torch.float16
+         )
+
+         # Default generation settings
+         self.default_generation_config = {
+             "max_new_tokens": 512,
+             "temperature": 0.7,
+             "top_p": 0.9,
+             "do_sample": True,
+             "repetition_penalty": 1.1
+         }
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         Process the incoming request
+
+         Args:
+             data: Dictionary containing:
+                 - inputs (str or dict): Text prompt or structured input
+                 - image (str, optional): Base64-encoded image
+                 - parameters (dict, optional): Generation parameters
+
+         Returns:
+             List containing the response dictionary
+         """
+         # Extract the input data
+         inputs = data.get("inputs", "")
+         image_data = data.get("image", None)
+         parameters = data.get("parameters", {})
+
+         # Merge generation parameters (request values override the defaults)
+         generation_config = {**self.default_generation_config, **parameters}
+
+         try:
+             # Dispatch on the input type
+             if isinstance(inputs, str):
+                 # Plain text input
+                 response = self._process_text(inputs, generation_config)
+             elif isinstance(inputs, dict):
+                 # Structured input (possibly with an image)
+                 text = inputs.get("text", "")
+                 image_b64 = inputs.get("image", image_data)
+
+                 if image_b64:
+                     response = self._process_multimodal(text, image_b64, generation_config)
+                 else:
+                     response = self._process_text(text, generation_config)
+             else:
+                 # Image provided separately from the prompt
+                 if image_data:
+                     response = self._process_multimodal(str(inputs), image_data, generation_config)
+                 else:
+                     response = self._process_text(str(inputs), generation_config)
+
+             return [{"generated_text": response}]
+
+         except Exception as e:
+             return [{"error": str(e)}]
+
+     def _process_text(self, text: str, generation_config: dict) -> str:
+         """
+         Process text-only input
+         """
+         # Build the message with an optimized system prompt
+         messages = [
+             {"role": "system", "content": "You are an expert assistant in mathematics and sciences. Provide clear, precise, and pedagogical answers. For each problem, explain your reasoning step by step, justify your choices, and illustrate with examples when necessary. Adopt an accessible yet rigorous style."},
+             {"role": "user", "content": [
+                 {"type": "text", "text": text}
+             ]}
+         ]
+
+         # Prepare the input
+         text_inputs = self.processor.apply_chat_template(
+             messages,
+             tokenize=True,
+             add_generation_prompt=True,
+             return_tensors="pt"
+         )
+
+         # Move to the model's device
+         text_inputs = text_inputs.to(self.model.device)
+
+         # Generate
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 text_inputs,
+                 **generation_config,
+                 pad_token_id=self.processor.tokenizer.eos_token_id,
+                 eos_token_id=self.processor.tokenizer.eos_token_id
+             )
+
+         # Decode only the newly generated tokens, which strips the prompt reliably
+         response = self.processor.decode(
+             outputs[0][text_inputs.shape[-1]:],
+             skip_special_tokens=True
+         )
+
+         return response.strip()
+
+     def _process_multimodal(self, text: str, image_b64: str, generation_config: dict) -> str:
+         """
+         Process combined text and image input
+         """
+         # Decode the base64 image
+         try:
+             image_bytes = base64.b64decode(image_b64)
+             image = Image.open(BytesIO(image_bytes)).convert("RGB")
+         except Exception as e:
+             raise ValueError(f"Error while decoding the image: {str(e)}")
+
+         # Build the multimodal message
+         messages = [
+             {"role": "system", "content": "You are an expert assistant in mathematics and sciences with multimodal reasoning capabilities. Provide clear, precise, and pedagogical answers. For each problem, explain your reasoning step by step, justify your choices, and illustrate with examples, diagrams, or visual aids when necessary. Analyze both textual and visual information carefully, and present your explanations in an accessible yet rigorous style."},
+             {"role": "user", "content": [
+                 {"type": "text", "text": text},
+                 {"type": "image"}
+             ]}
+         ]
+
+         # Prepare the prompt
+         prompt = self.processor.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             tokenize=False
+         )
+
+         # Tokenize the prompt together with the image
+         inputs = self.processor(
+             text=prompt,
+             images=[image],
+             return_tensors="pt"
+         )
+
+         # Move to the model's device
+         inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+
+         # Generate
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 **generation_config,
+                 pad_token_id=self.processor.tokenizer.eos_token_id,
+                 eos_token_id=self.processor.tokenizer.eos_token_id
+             )
+
+         # Decode only the newly generated tokens, which strips the prompt reliably
+         response = self.processor.decode(
+             outputs[0][inputs["input_ids"].shape[-1]:],
+             skip_special_tokens=True
+         )
+
+         return response.strip()
+
+     def health(self) -> Dict[str, Any]:
+         """
+         Health check endpoint
+         """
+         return {
+             "status": "healthy",
+             "model": "QwenStem-7b",
+             "device": str(self.model.device),
+             "quantization": "4-bit"
+         }
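
A minimal local smoke test of this handler. The sketch below is not part of the commit: the checkout path, the image filename, and the prompts are placeholder assumptions, and it presumes the file above is saved as handler.py next to the model weights.

import base64

from handler import EndpointHandler

# Placeholder path: point this at a local checkout of the model repository.
handler = EndpointHandler(path=".")

# Text-only request
print(handler({"inputs": "Solve x^2 - 5x + 6 = 0 step by step."}))

# Multimodal request: base64-encode a local image (placeholder filename)
with open("figure.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

print(handler({
    "inputs": {"text": "Describe the diagram.", "image": image_b64},
    "parameters": {"max_new_tokens": 256}
}))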