sreejith8100 commited on
Commit
5456c8e
·
verified ·
1 Parent(s): 30ad480

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +59 -44
handler.py CHANGED
@@ -50,60 +50,75 @@ class EndpointHandler:
50
  use_auth_token=hf_token
51
  ).eval().cuda()
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def predict(self, request):
54
  """
55
  Expected input:
56
  {
57
- "inputs": {
58
- "image": "<image URL or base64>",
59
- "msgs": [{"role": "user", "content": "<image>\nWhat is this?"}],
60
- "stream": false
61
- }
62
  }
63
  """
64
- # Extract data from the "inputs" key in the request
65
- data = request.get("inputs", {}) # Extract from "inputs" key
66
- image_input = data.get("image")
67
- msgs = data.get("msgs", [])
68
- stream = data.get("stream", False)
69
 
70
- if not image_input or not msgs:
71
- return {"error": "Missing 'image' or 'msgs'."}
72
 
73
  try:
74
- # Load image from URL or base64
75
- if image_input.startswith("http"):
76
- resp = requests.get(image_input, verify=False)
77
- image = Image.open(BytesIO(resp.content)).convert("RGB")
78
- else:
79
- image = Image.open(BytesIO(base64.b64decode(image_input))).convert("RGB")
80
- except Exception as e:
81
- return {"error": f"Invalid image format or URL: {e}"}
82
 
83
- try:
84
- # Run inference with or without streaming
85
- if stream:
86
- generated_text = ""
87
- for chunk in self.model.chat(
88
- image=image,
89
- msgs=msgs,
90
- tokenizer=self.tokenizer,
91
- sampling=True,
92
- stream=True
93
- ):
94
- generated_text += chunk
95
- return {"output": generated_text}
96
- else:
97
- output = self.model.chat(
98
- image=image,
99
- msgs=msgs,
100
- tokenizer=self.tokenizer,
101
- sampling=True,
102
- stream=False
103
- )
104
- return {"output": output}
105
- except Exception as e:
106
- return {"error": f"Inference failed: {e}"}
107
 
108
  def __call__(self, data):
109
  """
 
50
  use_auth_token=hf_token
51
  ).eval().cuda()
52
 
53
+ def load_image(self, image_input):
54
+ """
55
+ Load image from URL, base64 or file input.
56
+ """
57
+ if image_input.startswith("http"):
58
+ # Load image from URL
59
+ try:
60
+ resp = requests.get(image_input, verify=False)
61
+ image = Image.open(BytesIO(resp.content)).convert("RGB")
62
+ return image
63
+ except Exception as e:
64
+ raise ValueError(f"Failed to fetch image from URL: {e}")
65
+
66
+ elif image_input.startswith("data:image"):
67
+ # Load base64 encoded image
68
+ try:
69
+ image = Image.open(BytesIO(base64.b64decode(image_input.split(",")[1]))).convert("RGB")
70
+ return image
71
+ except Exception as e:
72
+ raise ValueError(f"Invalid base64 image format: {e}")
73
+
74
+ else:
75
+ # Load image from file
76
+ try:
77
+ image = Image.open(image_input).convert("RGB")
78
+ return image
79
+ except Exception as e:
80
+ raise ValueError(f"Failed to open image from file path: {e}")
81
+
82
  def predict(self, request):
83
  """
84
  Expected input:
85
  {
86
+ "image": "<image URL, file path, or base64>",
87
+ "question": "What is this?",
88
+ "stream": false
 
 
89
  }
90
  """
91
+ image_input = request.get("image")
92
+ question = request.get("question", "What is in the image?")
93
+ stream = request.get("stream", False)
 
 
94
 
95
+ if not image_input:
96
+ return {"error": "Missing image."}
97
 
98
  try:
99
+ # Load image using the new load_image function
100
+ image = self.load_image(image_input)
 
 
 
 
 
 
101
 
102
+ # Prepare message with <image> placeholder
103
+ msgs = [{"role": "user", "content": f"<image>\n{question}"}]
104
+
105
+ try:
106
+ if stream:
107
+ generated_text = ""
108
+ for chunk in self.model.chat(
109
+ image=image, msgs=msgs, tokenizer=self.tokenizer,
110
+ sampling=True, stream=True
111
+ ):
112
+ generated_text += chunk
113
+ return {"output": generated_text}
114
+ else:
115
+ output = self.model.chat(image=image, msgs=msgs, tokenizer=self.tokenizer)
116
+ return {"output": output}
117
+ except Exception as e:
118
+ return {"error": f"Inference failed: {e}"}
119
+
120
+ except ValueError as e:
121
+ return {"error": f"Image processing error: {e}"}
 
 
 
 
122
 
123
  def __call__(self, data):
124
  """