Abdulmateen commited on
Commit
b2f905d
·
verified ·
1 Parent(s): 58f78c1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +23 -24
handler.py CHANGED
@@ -7,33 +7,27 @@ import base64
7
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
10
- # The 'path' argument is now the path to your single, merged model repository
11
-
12
  print(f"Loading processor and model from: {path}...")
13
-
14
- # --- SIMPLIFIED LOADING ---
15
- # No more base model or PeftModel. Load everything directly.
16
- self.processor = AutoProcessor.from_pretrained(path, revision="a272c74")
17
  self.model = LlavaForConditionalGeneration.from_pretrained(
18
  path,
19
  load_in_4bit=True,
20
  torch_dtype=torch.float16,
21
  device_map="auto"
22
  )
23
- # --- END OF SIMPLIFICATION ---
24
-
25
  print("✅ Model loaded successfully.")
26
 
27
  def __call__(self, data: dict) -> dict:
28
- # The inference logic remains the same
29
  payload = data.pop("inputs", data)
30
-
31
  prompt_text = payload.pop("prompt", "Describe the image in detail.")
32
  image_url = payload.pop("image_url", None)
33
  image_b64 = payload.pop("image_b64", None)
34
  max_new_tokens = payload.pop("max_new_tokens", 200)
35
 
36
- # Load image from either a URL or a base64 string
 
37
  if image_url:
38
  try:
39
  response = requests.get(image_url)
@@ -47,21 +41,26 @@ class EndpointHandler:
47
  image = Image.open(BytesIO(image_bytes))
48
  except Exception as e:
49
  return {"error": f"Failed to decode base64 image: {e}"}
50
- else:
51
- return {"error": "No image provided. Please use 'image_url' or 'image_b64'."}
52
-
53
- # Format the prompt for LLaVA
54
- prompt = f"USER: <image>\n{prompt_text} ASSISTANT:"
55
 
56
- # Process inputs
57
- inputs = self.processor(text=prompt, images=image, return_tensors="pt").to("cuda")
 
 
 
 
 
 
58
 
59
- # Generate a response
60
- with torch.no_grad():
 
 
 
 
 
61
  output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
 
62
 
63
- # Decode and clean up the response
64
- full_response = self.processor.decode(output[0], skip_special_tokens=True)
65
  assistant_response = full_response.split("ASSISTANT:")[-1].strip()
66
-
67
- return {"generated_text": assistant_response}
 
7
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
10
+ # This part remains the same
 
11
  print(f"Loading processor and model from: {path}...")
12
+ self.processor = AutoProcessor.from_pretrained(path) # Removed revision for broader compatibility
 
 
 
13
  self.model = LlavaForConditionalGeneration.from_pretrained(
14
  path,
15
  load_in_4bit=True,
16
  torch_dtype=torch.float16,
17
  device_map="auto"
18
  )
 
 
19
  print("✅ Model loaded successfully.")
20
 
21
  def __call__(self, data: dict) -> dict:
 
22
  payload = data.pop("inputs", data)
23
+
24
  prompt_text = payload.pop("prompt", "Describe the image in detail.")
25
  image_url = payload.pop("image_url", None)
26
  image_b64 = payload.pop("image_b64", None)
27
  max_new_tokens = payload.pop("max_new_tokens", 200)
28
 
29
+ image = None
30
+ # Try to load an image if provided
31
  if image_url:
32
  try:
33
  response = requests.get(image_url)
 
41
  image = Image.open(BytesIO(image_bytes))
42
  except Exception as e:
43
  return {"error": f"Failed to decode base64 image: {e}"}
 
 
 
 
 
44
 
45
+ # --- NEW LOGIC: Check if an image is present ---
46
+ if image is not None:
47
+ # --- Case 1: Multimodal (Image + Text) ---
48
+ print("Processing multimodal request...")
49
+ prompt = f"USER: <image>\n{prompt_text} ASSISTANT:"
50
+ inputs = self.processor(text=prompt, images=image, return_tensors="pt").to("cuda")
51
+ output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
52
+ full_response = self.processor.decode(output[0], skip_special_tokens=True)
53
 
54
+ else:
55
+ # --- Case 2: Text-Only ---
56
+ print("Processing text-only request...")
57
+ prompt = f"USER: {prompt_text} ASSISTANT:"
58
+ # Note: We do NOT pass the 'images' argument here
59
+ inputs = self.processor(text=prompt, return_tensors="pt").to("cuda")
60
+ # Note: We do NOT pass the 'images' keyword to generate()
61
  output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
62
+ full_response = self.processor.decode(output[0], skip_special_tokens=True)
63
 
64
+ # Clean up the response to get only the assistant's part
 
65
  assistant_response = full_response.split("ASSISTANT:")[-1].strip()
66
+ return {"generated_text": assistant_response}