# File size: 1,966 Bytes
# a198bf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from typing import Dict, Any
from transformers import JanusForConditionalGeneration, JanusProcessor
import torch, base64, io, PIL.Image as Image

class EndpointHandler:
    """
    Custom inference handler for a Janus model.

    Works for:
      • text → text chat completions
      • text → image generation    (pass {"generation_mode":"image"})

    Request options (``generation_mode``, ``max_new_tokens``, ``do_sample``,
    ``return_full_text``) may be given at the top level of the payload or
    inside a ``"parameters"`` dict; top-level values take precedence.
    """

    _GENERATION_MODES = ("text", "image")

    def __init__(self, model_path: str, load_in_4bit: bool = True):
        """Load the Janus processor and model.

        Args:
            model_path: directory or hub id of the Janus checkpoint.
            load_in_4bit: quantize weights to 4 bit via bitsandbytes
                (default, as before); pass False on GPUs with enough memory.
        """
        # Local import keeps this change self-contained; the module header
        # only imports the Janus classes from transformers.
        from transformers import BitsAndBytesConfig

        self.processor = JanusProcessor.from_pretrained(
            model_path, trust_remote_code=True
        )
        # Passing ``load_in_4bit`` straight to ``from_pretrained`` is
        # deprecated; the supported route is a ``BitsAndBytesConfig``.
        # The compute dtype matches ``torch_dtype`` so 4-bit matmuls are not
        # silently upcast to float32.
        quantization_config = (
            BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            if load_in_4bit
            else None
        )
        self.model = JanusForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,   # fp16 also fine
            device_map="auto",
            quantization_config=quantization_config,
        )

    @staticmethod
    def _option(data: Dict[str, Any], key: str, default: Any) -> Any:
        """Return request option *key*, looking at the top level of *data*
        first and then inside ``data["parameters"]`` (if it is a dict)."""
        if key in data:
            return data[key]
        params = data.get("parameters")
        if isinstance(params, dict):
            return params.get(key, default)
        return default

    # ---- each request lands here ----
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Serve one request.

        Args:
            data: request payload; the prompt is read from ``"prompt"`` or,
                if that is missing/empty, from ``"inputs"``.

        Returns:
            ``{"generated_text": str}`` in text mode, or
            ``{"images": [base64-encoded PNG str, ...]}`` in image mode.

        Raises:
            ValueError: if no non-empty string prompt is supplied, or
                ``generation_mode`` is neither ``"text"`` nor ``"image"``.
        """
        prompt = data.get("prompt") or data.get("inputs")
        if not isinstance(prompt, str) or not prompt:
            raise ValueError(
                "request needs a non-empty 'prompt' or 'inputs' string"
            )

        gen_mode = self._option(data, "generation_mode", "text")
        if gen_mode not in self._GENERATION_MODES:
            raise ValueError(
                f"generation_mode must be 'text' or 'image', got {gen_mode!r}"
            )

        templ = self.processor.apply_chat_template(
            [{"role": "user",
              "content": [{"type": "text", "text": prompt}]}],
            add_generation_prompt=True,
        )

        inputs = self.processor(
            text=templ,
            generation_mode=gen_mode,
            return_tensors="pt",
        ).to(self.model.device)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": self._option(data, "max_new_tokens", 128),
        }
        # Only override the model's generation config when asked, except in
        # image mode, where the reference Janus example samples.
        do_sample = self._option(data, "do_sample", None)
        if do_sample is None and gen_mode == "image":
            do_sample = True
        if do_sample is not None:
            gen_kwargs["do_sample"] = do_sample

        # No autograd bookkeeping for generation or the image decoder.
        with torch.inference_mode():
            out = self.model.generate(
                **inputs,
                generation_mode=gen_mode,
                **gen_kwargs,
            )
            if gen_mode == "image":
                return {"images": self._images_to_base64(out)}

        tokens = out[0]
        if not self._option(data, "return_full_text", True):
            # Decoder-only generate output begins with the prompt ids.
            tokens = tokens[inputs["input_ids"].shape[-1]:]
        return {"generated_text":
                self.processor.decode(tokens, skip_special_tokens=True)}

    def _images_to_base64(self, image_tokens) -> list:
        """Turn image-mode ``generate`` output into base64 PNG strings."""
        # Image-mode generate yields VQ image-token ids, not pixels: they must
        # go through the model's VQ decoder and the processor's
        # post-processing to become PIL images (one per returned sequence).
        pixels = self.model.decode_image_tokens(image_tokens)
        images = self.processor.postprocess(
            list(pixels.float()), return_tensors="PIL.Image.Image"
        )
        return [self._pil_to_base64(img) for img in images["pixel_values"]]

    @staticmethod
    def _pil_to_base64(img: "Image.Image") -> str:
        """Encode *img* as PNG and return it as a base64 ASCII string."""
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode()