# handler.py
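# Custom handler for a Hugging Face Inference Endpoint serving a multimodal
# model loaded with trust_remote_code. Assumption: the model exposes an
# Ovis-style API (get_text_tokenizer / get_visual_tokenizer /
# preprocess_inputs), which the calls below rely on.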
import io, os, base64, requests, torch
from PIL import Image
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

class EndpointHandler:
    def __init__(self, path=""):
        # Optional: enable 4-bit quantization via an environment variable
        load_4bit = os.getenv("LOAD_IN_4BIT", "0") == "1"
        qcfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) if load_4bit else None

        # ⚠️ Use the local model path provided by the endpoint
        # ⚠️ Force the "eager" attention implementation (no SDPA, no FlashAttention)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            quantization_config=qcfg,
            device_map="auto",
            multimodal_max_length=int(os.getenv("MULTIMODAL_MAX_LENGTH", "8192")),
            attn_implementation="eager",
            llm_attn_implementation="eager",  # this field is read by some VLM architectures
        )
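        # Tokenizers are provided by the model itself (trust_remote_code API)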
        self.txt_tok = self.model.get_text_tokenizer()
        self.vis_tok = self.model.get_visual_tokenizer()

    def _load_image(self, spec):
        if "url" in spec:
            r = requests.get(spec["url"], timeout=10)
            r.raise_for_status()
            return Image.open(io.BytesIO(r.content)).convert("RGB")
        if "base64" in spec:
            return Image.open(io.BytesIO(base64.b64decode(spec["base64"]))).convert("RGB")
        raise ValueError("image must have 'url' or 'base64'")

    def __call__(self, data):
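        # Expected payload, derived from the keys read below:
        #   {
        #     "prompt": str,
        #     "images": [{"url": "..."} or {"base64": "..."}],  # optional
        #     "max_new_tokens": int                             # optional, default 512
        #   }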
        prompt = data.get("prompt", "")
        imgs_spec = data.get("images", [])
        max_new = int(data.get("max_new_tokens", 512))

        images = [self._load_image(s) for s in imgs_spec]
        if images:
            prefix = "\n".join(["<image>"] * len(images))
            query = f"{prefix}\n{prompt}"
            max_part = 4 if len(images) > 1 else 9
        else:
            query, max_part = prompt, None

        prompt, input_ids, pix = self.model.preprocess_inputs(query, images, max_partition=max_part)
        attn = (input_ids != self.txt_tok.pad_token_id).unsqueeze(0).to(self.model.device)
        input_ids = input_ids.unsqueeze(0).to(self.model.device)
        # Assumed Ovis-style API: generate() takes pixel_values as a list with
        # one entry per sample, so pass [None] for text-only queries
        pix = [pix.to(dtype=self.vis_tok.dtype, device=self.vis_tok.device)] if pix is not None else [None]

        with torch.inference_mode():
            out_ids = self.model.generate(
                input_ids,
                pixel_values=pix,
                attention_mask=attn,
                max_new_tokens=max_new,
                do_sample=False,
                use_cache=True,
                eos_token_id=self.model.generation_config.eos_token_id,
                pad_token_id=self.txt_tok.pad_token_id,
            )[0]
        text = self.txt_tok.decode(out_ids, skip_special_tokens=True)
        return {"output": text}
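
# --- Local smoke test (sketch) ---------------------------------------------
# Assumptions: a model snapshot already exists at MODEL_DIR, and the sample
# URL below is a placeholder; neither is part of the endpoint configuration.
if __name__ == "__main__":
    handler = EndpointHandler(path=os.getenv("MODEL_DIR", "."))
    result = handler({
        "prompt": "Describe this image.",
        "images": [{"url": "https://example.com/sample.jpg"}],  # placeholder
        "max_new_tokens": 128,
    })
    print(result["output"])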