Abdulmateen committed on
Commit dc92b57 · verified · 1 Parent(s): 009d0db

Update handler.py: switch to transformers + peft for loading the base model and LoRA adapter, and unwrap the 'inputs' payload sent by Inference Endpoints

Files changed (1)
  1. handler.py +37 -104
handler.py CHANGED
@@ -1,140 +1,73 @@
 import torch
-from transformers import BitsAndBytesConfig
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+from peft import PeftModel
 from PIL import Image
 import requests
 from io import BytesIO
 import base64
 
-# Import LLaVA's specific tools
-from llava.model.builder import load_pretrained_model
-from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
-from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
-
 class EndpointHandler:
     def __init__(self, path=""):
-        # path is the local path to your LoRA adapter repository
-
-        # This single, official LLaVA function handles everything:
-        # 1. Loads the base model (llava-1.5-7b-hf)
-        # 2. Correctly applies your LoRA adapter from the `path`
-        # 3. Loads the correct processor and tokenizer
-        model_name = get_model_name_from_path(path)
-        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-            model_path=path,                        # Path to your LoRA adapter
-            model_base="llava-hf/llava-1.5-7b-hf",  # Base model to load underneath
-            model_name=model_name,
-            load_in_4bit=True,                      # Load in 4-bit for efficiency
+        # The 'path' argument will be the path to your LoRA repo on the Hub
+        base_model_id = "llava-hf/llava-1.5-7b-hf"
+
+        print("Loading processor...")
+        self.processor = AutoProcessor.from_pretrained(base_model_id, revision="a272c74")
+
+        print("Loading base model...")
+        self.model = LlavaForConditionalGeneration.from_pretrained(
+            base_model_id,
+            load_in_4bit=True,
+            torch_dtype=torch.float16,
             device_map="auto"
         )
-        self.model.eval()
-        print(" Base model and LoRA adapter loaded successfully using LLaVA's official loader.")
+
+        print(f"Loading LoRA adapters from repository path: {path}...")
+        self.model = PeftModel.from_pretrained(self.model, path)
+        print("✅ Model and adapters loaded successfully.")
 
     def __call__(self, data: dict) -> dict:
-        # Get the payload from the 'inputs' key
+        # --- THIS IS THE FIX ---
+        # The Inference Endpoint wraps the payload in an "inputs" key.
+        # We must extract our data from there first.
         payload = data.pop("inputs", data)
-
+        # --- END OF FIX ---
+
+        # Now, get the prompt and image from the extracted payload
        prompt_text = payload.pop("prompt", "Describe the image in detail.")
         image_url = payload.pop("image_url", None)
         image_b64 = payload.pop("image_b64", None)
         max_new_tokens = payload.pop("max_new_tokens", 200)
 
+        # Load image from either a URL or a base64 string
         if image_url:
             try:
                 response = requests.get(image_url)
-                response.raise_for_status()
-                image = Image.open(BytesIO(response.content)).convert("RGB")
+                response.raise_for_status()  # Raise an exception for bad status codes
+                image = Image.open(BytesIO(response.content))
             except Exception as e:
                 return {"error": f"Failed to load image from URL: {e}"}
         elif image_b64:
             try:
                 image_bytes = base64.b64decode(image_b64)
-                image = Image.open(BytesIO(image_bytes)).convert("RGB")
+                image = Image.open(BytesIO(image_bytes))
             except Exception as e:
                 return {"error": f"Failed to decode base64 image: {e}"}
         else:
             return {"error": "No image provided. Please use 'image_url' or 'image_b64'."}
 
-        # Process the image
-        image_tensor = process_images([image], self.image_processor, self.model.config)
-        image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
-
-        # Format the prompt correctly for LLaVA v1.5
-        # Note: The format is slightly different from the previous generic one
-        conv_mode = "llava_v1"
-        conv = Conversation(
-            system="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
-            roles=("USER", "ASSISTANT"),
-            version="v1",
-            messages=(),
-            offset=0,
-            sep_style=SeparatorStyle.TWO,
-            sep=" ",
-            sep2="</s>",
-        )
+        # Format the prompt for LLaVA
+        prompt = f"USER: <image>\n{prompt_text} ASSISTANT:"
 
-        prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text
-        conv.append_message(conv.roles[0], prompt)
-        conv.append_message(conv.roles[1], None)
-        prompt_with_template = conv.get_prompt()
-
-        input_ids = tokenizer_image_token(prompt_with_template, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+        # Process inputs
+        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to("cuda")
 
         # Generate a response
-        with torch.inference_mode():
-            output_ids = self.model.generate(
-                input_ids,
-                images=image_tensor,
-                do_sample=True,
-                temperature=0.2,
-                max_new_tokens=max_new_tokens,
-                use_cache=True,
-            )
+        with torch.no_grad():
+            output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
 
-        # Decode and return the response
-        response_text = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
+        # Decode and clean up the response
+        full_response = self.processor.decode(output[0], skip_special_tokens=True)
+        assistant_response = full_response.split("ASSISTANT:")[-1].strip()
 
-        return {"generated_text": response_text}
-
-# This Conversation class is a helper taken from LLaVA's library for correct prompt formatting
-import dataclasses
-from enum import auto, Enum
-from typing import List, Tuple
-
-class SeparatorStyle(Enum):
-    SINGLE = auto()
-    TWO = auto()
-
-@dataclasses.dataclass
-class Conversation:
-    system: str
-    roles: List[str]
-    messages: List[List[str]]
-    offset: int
-    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
-    sep: str = "###"
-    sep2: str = None
-    version: str = "Unknown"
-
-    def get_prompt(self):
-        if self.sep_style == SeparatorStyle.SINGLE:
-            ret = self.system + self.sep
-            for role, message in self.messages:
-                if message:
-                    ret += role + ": " + message + self.sep
-                else:
-                    ret += role + ":"
-            return ret
-        elif self.sep_style == SeparatorStyle.TWO:
-            seps = [self.sep, self.sep2]
-            ret = self.system + seps[0]
-            for i, (role, message) in enumerate(self.messages):
-                if message:
-                    ret += role + ": " + message + seps[i % 2]
-                else:
-                    ret += role + ":"
-            return ret
-        else:
-            raise ValueError(f"Invalid style: {self.sep_style}")
-
-    def append_message(self, role, message):
-        self.messages.append([role, message])
+        return {"generated_text": assistant_response}
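For reference, a minimal sketch of how the updated handler could be smoke-tested locally before deployment. The adapter repo id and image URL below are placeholders, not values from this commit; a CUDA GPU is assumed:

# smoke_test.py -- hypothetical local check of the new handler.
# "your-username/your-llava-lora" and the image URL are placeholders.
from handler import EndpointHandler

handler = EndpointHandler(path="your-username/your-llava-lora")

# Note the "inputs" wrapper: this mirrors the payload shape an
# Inference Endpoint delivers, which is exactly what the fix unwraps.
result = handler({
    "inputs": {
        "prompt": "What is shown in this image?",
        "image_url": "https://example.com/sample.jpg",  # placeholder
        "max_new_tokens": 100,
    }
})
print(result.get("generated_text", result))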
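One caveat on the loading step: recent transformers releases deprecate passing load_in_4bit=True directly to from_pretrained in favor of an explicit BitsAndBytesConfig (the import the old handler pulled in but never used). If the endpoint image ships a newer transformers, the equivalent load would look roughly like this sketch:

import torch
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration

# Equivalent 4-bit load via an explicit quantization config; the
# compute dtype mirrors the torch_dtype=torch.float16 used above.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    quantization_config=quant_config,
    device_map="auto",
)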