Abdulmateen commited on
Commit
009d0db
·
verified ·
1 Parent(s): 8ca6669

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +105 -40
handler.py CHANGED
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from peft import PeftModel
from PIL import Image
import requests
from io import BytesIO
import base64


class EndpointHandler:
    """Inference-endpoint handler serving LLaVA-1.5-7B with LoRA adapters applied."""

    def __init__(self, path=""):
        # 'path' is the path to the LoRA adapter repository on the Hub.
        base_model_id = "llava-hf/llava-1.5-7b-hf"

        print("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(base_model_id, revision="a272c74")

        print("Loading base model...")
        self.model = LlavaForConditionalGeneration.from_pretrained(
            base_model_id,
            load_in_4bit=True,          # 4-bit quantization for memory efficiency
            torch_dtype=torch.float16,
            device_map="auto",
        )

        print(f"Loading LoRA adapters from repository path: {path}...")
        self.model = PeftModel.from_pretrained(self.model, path)
        # Fix: switch to eval mode so dropout/norm layers behave deterministically
        # at inference time.
        self.model.eval()
        print("✅ Model and adapters loaded successfully.")

    def __call__(self, data: dict) -> dict:
        """Handle one request.

        Payload keys (under "inputs" or top-level): prompt (str),
        image_url (str) or image_b64 (str), max_new_tokens (int),
        optionally lora_path (str) to swap adapters.
        Returns {"generated_text": str} on success, {"error": str} on failure.
        """
        payload = data.pop("inputs", data)

        # Optional dynamic override of LoRA adapters (costly if done per request).
        lora_override = payload.pop("lora_path", None)
        if lora_override:
            # Fix: unwrap to the underlying base model before re-attaching adapters.
            # The original re-wrapped `self.model.base_model`, which — once the first
            # override has happened — is a LoraModel, stacking PEFT wrappers on every
            # subsequent request.
            base = self.model.get_base_model() if isinstance(self.model, PeftModel) else self.model
            self.model = PeftModel.from_pretrained(base, lora_override)

        prompt_text = payload.pop("prompt", "Describe the image in detail.")
        image_url = payload.pop("image_url", None)
        image_b64 = payload.pop("image_b64", None)
        max_new_tokens = payload.pop("max_new_tokens", 200)

        # Load image from either a URL or a base64 string.
        if image_url:
            try:
                # Fix: bound the request so a dead URL cannot hang the endpoint.
                response = requests.get(image_url, timeout=30)
                response.raise_for_status()  # raise on bad HTTP status codes
                # Fix: normalize to RGB — palette/RGBA/grayscale images otherwise
                # break the processor's channel normalization.
                image = Image.open(BytesIO(response.content)).convert("RGB")
            except Exception as e:
                return {"error": f"Failed to load image from URL: {e}"}
        elif image_b64:
            try:
                image_bytes = base64.b64decode(image_b64)
                image = Image.open(BytesIO(image_bytes)).convert("RGB")
            except Exception as e:
                return {"error": f"Failed to decode base64 image: {e}"}
        else:
            return {"error": "No image provided. Please use 'image_url' or 'image_b64'."}

        # Format the prompt for LLaVA-1.5's chat template.
        prompt = f"USER: <image>\n{prompt_text} ASSISTANT:"

        # Fix: follow the model's actual device instead of hard-coding "cuda"
        # (device_map="auto" decides placement).
        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Decode and keep only the assistant's turn.
        full_response = self.processor.decode(output[0], skip_special_tokens=True)
        assistant_response = full_response.split("ASSISTANT:")[-1].strip()

        return {"generated_text": assistant_response}
1
  import torch
2
+ from transformers import BitsAndBytesConfig
 
3
  from PIL import Image
4
  import requests
5
  from io import BytesIO
6
  import base64
7
 
8
+ # Import LLaVA's specific tools
9
+ from llava.model.builder import load_pretrained_model
10
+ from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
11
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
12
+
13
class EndpointHandler:
    """Serves a LLaVA-1.5 base model with a LoRA adapter via LLaVA's official loader."""

    def __init__(self, path=""):
        # `path` is the local path to the LoRA adapter repository.
        #
        # load_pretrained_model handles everything at once:
        #   1. loads the base model,
        #   2. applies the LoRA adapter found at `path`,
        #   3. returns the matching tokenizer and image processor.
        model_name = get_model_name_from_path(path)
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=path,                        # path to the LoRA adapter
            model_base="llava-hf/llava-1.5-7b-hf",  # base model to load underneath
            model_name=model_name,
            # NOTE(review): LLaVA's builder names this flag `load_4bit`; confirm
            # that `load_in_4bit` is actually honored and not silently dropped.
            load_in_4bit=True,
            device_map="auto"
        )
        # Eval mode: disable dropout etc. for deterministic-ish inference.
        self.model.eval()
        print(" Base model and LoRA adapter loaded successfully using LLaVA's official loader.")

    def __call__(self, data: dict) -> dict:
        """Handle one generation request.

        Payload keys (under "inputs" or top-level): prompt (str),
        image_url (str) or image_b64 (str), max_new_tokens (int).
        Returns {"generated_text": str} on success, {"error": str} on failure.
        """
        # Get the payload from the 'inputs' key
        payload = data.pop("inputs", data)

        prompt_text = payload.pop("prompt", "Describe the image in detail.")
        image_url = payload.pop("image_url", None)
        image_b64 = payload.pop("image_b64", None)
        max_new_tokens = payload.pop("max_new_tokens", 200)

        if image_url:
            try:
                # Fix: bound the request so a dead URL cannot hang the endpoint.
                response = requests.get(image_url, timeout=30)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content)).convert("RGB")
            except Exception as e:
                return {"error": f"Failed to load image from URL: {e}"}
        elif image_b64:
            try:
                image_bytes = base64.b64decode(image_b64)
                image = Image.open(BytesIO(image_bytes)).convert("RGB")
            except Exception as e:
                return {"error": f"Failed to decode base64 image: {e}"}
        else:
            return {"error": "No image provided. Please use 'image_url' or 'image_b64'."}

        # Preprocess the image into the model's expected tensor format.
        image_tensor = process_images([image], self.image_processor, self.model.config)
        image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)

        # Build the llava_v1 conversation prompt.
        conv = Conversation(
            system="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
            roles=("USER", "ASSISTANT"),
            version="v1",
            # Fix: must be a mutable list — append_message() calls .append(),
            # which raised AttributeError on the original empty tuple `()`.
            messages=[],
            offset=0,
            sep_style=SeparatorStyle.TWO,
            sep=" ",
            sep2="</s>",
        )

        prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)  # open the assistant's slot
        prompt_with_template = conv.get_prompt()

        # Fix: follow the model's actual device instead of hard-coding .cuda()
        # (device_map="auto" decides placement).
        input_ids = tokenizer_image_token(
            prompt_with_template, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(self.model.device)

        # Generate a response
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=0.2,
                max_new_tokens=max_new_tokens,
                use_cache=True,
            )

        # Decode only the newly generated tokens.
        # Fix: skip special tokens so the trailing "</s>" separator is not
        # returned to the caller.
        response_text = self.tokenizer.decode(
            output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
        ).strip()

        return {"generated_text": response_text}
# Helper classes adapted from LLaVA's conversation library for correct prompt formatting.
import dataclasses
from enum import auto, Enum
from typing import List, Optional, Tuple


class SeparatorStyle(Enum):
    """Separator convention used when flattening a conversation to one string."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A minimal chat transcript that renders to a single prompt string."""
    system: str
    # Fix: callers pass a ("USER", "ASSISTANT") 2-tuple, not a list; `Tuple`
    # was imported but unused in the original.
    roles: Tuple[str, str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    # Fix: default is None, so the annotation must be Optional[str], not str.
    sep2: Optional[str] = None
    version: str = "Unknown"

    def get_prompt(self) -> str:
        """Render the system prompt plus all turns using the configured separators.

        A turn whose message is falsy (e.g. None) renders as "ROLE:" with no
        separator — this is how the open assistant slot is produced.

        Raises:
            ValueError: if `sep_style` is not a recognized SeparatorStyle.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            # Alternate between sep (after user turns) and sep2 (after assistant turns).
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append one turn; `message` may be None to open an unanswered slot."""
        self.messages.append([role, message])