File size: 3,264 Bytes
58f78c1
 
 
 
89a2202
58f78c1
 
 
5ee030f
89a2202
 
58f78c1
89a2202
58f78c1
 
89a2202
 
58f78c1
89a2202
58f78c1
 
5ee030f
 
 
 
 
 
2eccf7e
5ee030f
 
 
 
 
 
 
 
5a65a07
5ee030f
 
2eccf7e
5ee030f
5a65a07
5ee030f
 
 
2eccf7e
 
 
5a65a07
2eccf7e
5a65a07
 
 
2eccf7e
 
 
5a65a07
 
 
2eccf7e
5a65a07
 
2eccf7e
 
 
 
 
 
 
 
 
 
89a2202
58f78c1
2eccf7e
dd7c1fd
2eccf7e
 
58f78c1
2eccf7e
b2f905d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import base64
from io import BytesIO

class EndpointHandler:
    """Inference endpoint handler for a merged LLaVA vision-language model.

    Accepts either request shape (optionally wrapped under an "inputs" key):
      * multimodal: {"prompt": str, "image_b64": base64 str, "max_new_tokens": int}
      * text-only:  {"prompt": str, "max_new_tokens": int}

    Returns {"generated_text": str} on success, {"error": str} if the
    supplied base64 image cannot be decoded.
    """

    def __init__(self, path: str = ""):
        # The 'path' is a self-contained directory with the complete, merged model.
        print("Loading model and processor from local path...")
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        # NOTE(review): `load_in_4bit=` is deprecated in recent transformers
        # releases in favor of `quantization_config=BitsAndBytesConfig(...)`;
        # kept as-is here for compatibility with the pinned runtime — confirm
        # the deployed transformers version before changing it.
        self.model = LlavaForConditionalGeneration.from_pretrained(
            path,
            load_in_4bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        print("✅ Model loaded successfully.")

    def __call__(self, data: dict) -> dict:
        # Use .get (not .pop) so the caller's payload dict is never mutated.
        payload = data.get("inputs", data)

        prompt_text = payload.get("prompt", "Describe the image in detail.")
        image_b64 = payload.get("image_b64")
        max_new_tokens = payload.get("max_new_tokens", 200)

        image = None
        if image_b64:
            try:
                image_bytes = base64.b64decode(image_b64)
                # Force RGB: palette / RGBA / grayscale inputs would otherwise
                # fail or mis-process in the image transform.
                image = Image.open(BytesIO(image_bytes)).convert("RGB")
            except Exception as e:
                return {"error": f"Failed to decode or open base64 image: {e}"}

        if image is not None:
            # --- Case 1: Multimodal (Image + Text) ---
            print("Processing multimodal request...")
            prompt = f"USER: <image>\n{prompt_text} ASSISTANT:"
            inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.model.device)
        else:
            # --- Case 2: Text-Only ---
            print("Processing text-only request...")
            prompt = f"USER: {prompt_text} ASSISTANT:"

            # First, process the text to get input_ids.
            inputs = self.processor(text=prompt, return_tensors="pt")

            # BUG FIX: HF image processors (e.g. CLIPImageProcessor) expose
            # `crop_size` as a direct attribute — there is no `.config`
            # attribute, so the previous `image_processor.config` access
            # raised AttributeError on every text-only request.
            image_processor = self.processor.image_processor
            # Fall back to LLaVA-1.5 defaults (3x336x336) if the attributes
            # are absent — TODO confirm against the deployed checkpoint.
            crop_size = getattr(image_processor, "crop_size", None) or {"height": 336, "width": 336}
            num_channels = getattr(image_processor, "num_channels", 3)

            # Some LLaVA builds require pixel_values even for text-only
            # prompts; feed an all-zero dummy image of the expected shape.
            dummy_pixel_values = torch.zeros(
                (1, num_channels, crop_size['height'], crop_size['width']),
                dtype=self.model.dtype,
                device=self.model.device
            )
            inputs['pixel_values'] = dummy_pixel_values

            # Ensure the entire input dictionary is on the correct device.
            inputs = inputs.to(self.model.device)

        # Generate the output (identical for both cases).
        with torch.no_grad():
            output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        full_response = self.processor.decode(output[0], skip_special_tokens=True)

        # Keep only the assistant's turn from the decoded transcript.
        assistant_response = full_response.split("ASSISTANT:")[-1].strip()

        return {"generated_text": assistant_response}