Create handler.py
Browse files- handler.py +198 -0
handler.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import base64
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
from typing import Dict, List, Any, Optional, Union
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
|
| 7 |
+
|
| 8 |
+
class EndpointHandler():
|
| 9 |
+
def __init__(self, path=""):
|
| 10 |
+
"""
|
| 11 |
+
Initialize the model handler
|
| 12 |
+
Args:
|
| 13 |
+
path: Path to the model weights (provided by HF Endpoints)
|
| 14 |
+
"""
|
| 15 |
+
# Configuration pour la quantification 4-bit
|
| 16 |
+
self.bnb_config = BitsAndBytesConfig(
|
| 17 |
+
load_in_4bit=True,
|
| 18 |
+
bnb_4bit_quant_type="nf4",
|
| 19 |
+
bnb_4bit_use_double_quant=True,
|
| 20 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
# Chargement du processeur et du modèle
|
| 24 |
+
self.processor = AutoProcessor.from_pretrained(
|
| 25 |
+
path,
|
| 26 |
+
trust_remote_code=True
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
self.model = AutoModelForVision2Seq.from_pretrained(
|
| 30 |
+
path,
|
| 31 |
+
quantization_config=self.bnb_config,
|
| 32 |
+
device_map="auto",
|
| 33 |
+
trust_remote_code=True,
|
| 34 |
+
torch_dtype=torch.float16
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Configuration de génération par défaut
|
| 38 |
+
self.default_generation_config = {
|
| 39 |
+
"max_new_tokens": 512,
|
| 40 |
+
"temperature": 0.7,
|
| 41 |
+
"top_p": 0.9,
|
| 42 |
+
"do_sample": True,
|
| 43 |
+
"repetition_penalty": 1.1
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 47 |
+
"""
|
| 48 |
+
Process the incoming request
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
data: Dictionary containing:
|
| 52 |
+
- inputs (str or dict): Text prompt or structured input
|
| 53 |
+
- image (str, optional): Base64 encoded image
|
| 54 |
+
- parameters (dict, optional): Generation parameters
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
List containing the response dictionary
|
| 58 |
+
"""
|
| 59 |
+
# Extraction des données d'entrée
|
| 60 |
+
inputs = data.get("inputs", "")
|
| 61 |
+
image_data = data.get("image", None)
|
| 62 |
+
parameters = data.get("parameters", {})
|
| 63 |
+
|
| 64 |
+
# Fusion des paramètres de génération
|
| 65 |
+
generation_config = {**self.default_generation_config, **parameters}
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
# Traitement selon le type d'entrée
|
| 69 |
+
if isinstance(inputs, str):
|
| 70 |
+
# Entrée texte simple
|
| 71 |
+
response = self._process_text(inputs, generation_config)
|
| 72 |
+
elif isinstance(inputs, dict):
|
| 73 |
+
# Entrée structurée (avec potentiellement une image)
|
| 74 |
+
text = inputs.get("text", "")
|
| 75 |
+
image_b64 = inputs.get("image", image_data)
|
| 76 |
+
|
| 77 |
+
if image_b64:
|
| 78 |
+
response = self._process_multimodal(text, image_b64, generation_config)
|
| 79 |
+
else:
|
| 80 |
+
response = self._process_text(text, generation_config)
|
| 81 |
+
else:
|
| 82 |
+
# Si l'image est fournie séparément
|
| 83 |
+
if image_data:
|
| 84 |
+
response = self._process_multimodal(str(inputs), image_data, generation_config)
|
| 85 |
+
else:
|
| 86 |
+
response = self._process_text(str(inputs), generation_config)
|
| 87 |
+
|
| 88 |
+
return [{"generated_text": response}]
|
| 89 |
+
|
| 90 |
+
except Exception as e:
|
| 91 |
+
return [{"error": str(e)}]
|
| 92 |
+
|
| 93 |
+
def _process_text(self, text: str, generation_config: dict) -> str:
|
| 94 |
+
"""
|
| 95 |
+
Process text-only input
|
| 96 |
+
"""
|
| 97 |
+
# Construction du message avec un prompt optimisé
|
| 98 |
+
messages = [
|
| 99 |
+
{"role": "system", "content": "You are an expert assistant in mathematics and sciences. Provide clear, precise, and pedagogical answers. For each problem, explain your reasoning step by step, justify your choices, and illustrate with examples when necessary. Adopt an accessible yet rigorous style."},
|
| 100 |
+
{"role": "user", "content": [
|
| 101 |
+
{"type": "text", "text": text}
|
| 102 |
+
]}
|
| 103 |
+
]
|
| 104 |
+
|
| 105 |
+
# Préparation de l'input
|
| 106 |
+
text_inputs = self.processor.apply_chat_template(
|
| 107 |
+
messages,
|
| 108 |
+
tokenize=True,
|
| 109 |
+
add_generation_prompt=True,
|
| 110 |
+
return_tensors="pt"
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# Déplacement sur le bon device
|
| 114 |
+
text_inputs = text_inputs.to(self.model.device)
|
| 115 |
+
|
| 116 |
+
# Génération
|
| 117 |
+
with torch.no_grad():
|
| 118 |
+
outputs = self.model.generate(
|
| 119 |
+
text_inputs,
|
| 120 |
+
**generation_config,
|
| 121 |
+
pad_token_id=self.processor.tokenizer.eos_token_id,
|
| 122 |
+
eos_token_id=self.processor.tokenizer.eos_token_id
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Décodage
|
| 126 |
+
response = self.processor.decode(outputs[0], skip_special_tokens=True)
|
| 127 |
+
|
| 128 |
+
# Extraction de la réponse (retirer le prompt)
|
| 129 |
+
if "assistant" in response:
|
| 130 |
+
response = response.split("assistant")[-1].strip()
|
| 131 |
+
|
| 132 |
+
return response
|
| 133 |
+
|
| 134 |
+
def _process_multimodal(self, text: str, image_b64: str, generation_config: dict) -> str:
|
| 135 |
+
"""
|
| 136 |
+
Process text and image input
|
| 137 |
+
"""
|
| 138 |
+
# Décodage de l'image base64
|
| 139 |
+
try:
|
| 140 |
+
image_bytes = base64.b64decode(image_b64)
|
| 141 |
+
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
| 142 |
+
except Exception as e:
|
| 143 |
+
raise ValueError(f"Erreur lors du décodage de l'image: {str(e)}")
|
| 144 |
+
|
| 145 |
+
# Construction du message multimodal
|
| 146 |
+
messages = [
|
| 147 |
+
{"role": "system", "content": "You are an expert assistant in mathematics and sciences with multimodal reasoning capabilities. Provide clear, precise, and pedagogical answers. For each problem, explain your reasoning step by step, justify your choices, and illustrate with examples, diagrams, or visual aids when necessary. Analyze both textual and visual information carefully, and present your explanations in an accessible yet rigorous style."},
|
| 148 |
+
{"role": "user", "content": [
|
| 149 |
+
{"type": "text", "text": text},
|
| 150 |
+
{"type": "image"}
|
| 151 |
+
]}
|
| 152 |
+
]
|
| 153 |
+
|
| 154 |
+
# Préparation du prompt
|
| 155 |
+
prompt = self.processor.apply_chat_template(
|
| 156 |
+
messages,
|
| 157 |
+
add_generation_prompt=True,
|
| 158 |
+
tokenize=False
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Traitement avec l'image
|
| 162 |
+
inputs = self.processor(
|
| 163 |
+
text=prompt,
|
| 164 |
+
images=[image],
|
| 165 |
+
return_tensors="pt"
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# Déplacement sur le bon device
|
| 169 |
+
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
| 170 |
+
|
| 171 |
+
# Génération
|
| 172 |
+
with torch.no_grad():
|
| 173 |
+
outputs = self.model.generate(
|
| 174 |
+
**inputs,
|
| 175 |
+
**generation_config,
|
| 176 |
+
pad_token_id=self.processor.tokenizer.eos_token_id,
|
| 177 |
+
eos_token_id=self.processor.tokenizer.eos_token_id
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Décodage
|
| 181 |
+
response = self.processor.decode(outputs[0], skip_special_tokens=True)
|
| 182 |
+
|
| 183 |
+
# Extraction de la réponse
|
| 184 |
+
if "assistant" in response:
|
| 185 |
+
response = response.split("assistant")[-1].strip()
|
| 186 |
+
|
| 187 |
+
return response
|
| 188 |
+
|
| 189 |
+
def health(self) -> Dict[str, Any]:
|
| 190 |
+
"""
|
| 191 |
+
Health check endpoint
|
| 192 |
+
"""
|
| 193 |
+
return {
|
| 194 |
+
"status": "healthy",
|
| 195 |
+
"model": "QwenStem-7b",
|
| 196 |
+
"device": str(self.model.device),
|
| 197 |
+
"quantization": "4-bit"
|
| 198 |
+
}
|