sreejith8100 commited on
Commit
99d1be1
·
verified ·
1 Parent(s): 030c1d8

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +35 -81
handler.py CHANGED
@@ -1,76 +1,28 @@
1
  import torch
2
  from PIL import Image
3
  from transformers import AutoModel, AutoTokenizer
4
- import requests
5
- from io import BytesIO
6
  import base64
7
- import ssl
8
- import urllib3
9
- import os
10
-
11
- # Check if CUDA is available
12
- print("CUDA Available:", torch.cuda.is_available())
13
-
14
- if torch.cuda.is_available():
15
- device_name = torch.cuda.get_device_name(torch.cuda.current_device())
16
- print(f"CUDA Device: {device_name}")
17
- print(f"Memory Allocated: {torch.cuda.memory_allocated()} bytes")
18
- print(f"Memory Cached: {torch.cuda.memory_reserved()} bytes")
19
- print(f"PyTorch Version: {torch.__version__}")
20
- print(f"CUDA Version (PyTorch uses): {torch.version.cuda}")
21
- else:
22
- print("CUDA is not available!")
23
-
24
- urllib3.disable_warnings()
25
- ssl._create_default_https_context = ssl._create_unverified_context
26
 
27
  class EndpointHandler:
28
  def __init__(self, model_dir=None):
29
  self.load_model()
30
 
31
  def load_model(self):
32
- model_name = "openbmb/MiniCPM-V-2_6"
33
- hf_token = os.getenv("HF_AUTH_TOKEN")
34
-
35
- self.tokenizer = AutoTokenizer.from_pretrained(
36
- model_name, trust_remote_code=True, use_auth_token=hf_token
37
- )
38
- self.model = AutoModel.from_pretrained(
39
- model_name,
40
- trust_remote_code=True,
41
- attn_implementation="sdpa",
42
- torch_dtype=torch.float16,
43
- use_auth_token=hf_token
44
- ).eval().cuda()
45
 
46
- def load_image(self, image_input):
47
- if image_input.startswith("http"):
48
- try:
49
- resp = requests.get(image_input, verify=False)
50
- image = Image.open(BytesIO(resp.content)).convert("RGB")
51
- return image
52
- except Exception as e:
53
- raise ValueError(f"Failed to fetch image from URL: {e}")
54
-
55
- elif image_input.startswith("data:image"):
56
- try:
57
- image = Image.open(BytesIO(base64.b64decode(image_input.split(",")[1]))).convert("RGB")
58
- return image
59
- except Exception as e:
60
- raise ValueError(f"Invalid base64 image format: {e}")
61
-
62
- else:
63
- try:
64
- image = Image.open(image_input).convert("RGB")
65
- return image
66
- except Exception as e:
67
- raise ValueError(f"Failed to open image from file path: {e}")
68
 
69
  def predict(self, request):
70
- # Unwrap Hugging Face format
71
- if "inputs" in request:
72
- request = request["inputs"]
73
-
74
  image_input = request.get("image")
75
  question = request.get("question", "What is in the image?")
76
  stream = request.get("stream", False)
@@ -80,28 +32,30 @@ class EndpointHandler:
80
 
81
  try:
82
  image = self.load_image(image_input)
83
- msgs = [{"role": "user", "content": f"<image>\n{question}"}]
84
-
85
- try:
86
- if stream:
87
- generated_text = ""
88
- for chunk in self.model.chat(
89
- image=None, msgs=msgs, tokenizer=self.tokenizer,
90
- sampling=True, stream=True
91
- ):
92
- generated_text += chunk
93
- return {"output": generated_text}
94
- else:
95
- output = self.model.chat(image=None, msgs=msgs, tokenizer=self.tokenizer)
96
- return {"output": output}
97
- except Exception as e:
98
- return {"error": f"Inference failed: {e}"}
99
-
100
- except ValueError as e:
101
- return {"error": f"Image processing error: {e}"}
 
 
102
 
103
  def __call__(self, data):
104
  return self.predict(data)
105
 
106
- # Hugging Face looks for a callable handler
107
- handler = EndpointHandler()
 
1
  import torch
2
  from PIL import Image
3
  from transformers import AutoModel, AutoTokenizer
 
 
4
  import base64
5
+ from io import BytesIO
6
+ import requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
class EndpointHandler:
    """Inference endpoint handler for the int4-quantized MiniCPM-V-2.6 VLM.

    Expects request dicts of the form::

        {"image": <bytes | base64 str | data-URI str | http(s) URL str>,
         "question": str (optional, default "What is in the image?"),
         "stream": bool (optional, default False)}

    and returns ``{"output": str}`` on success or ``{"error": str}`` on failure.
    """

    def __init__(self, model_dir=None):
        # `model_dir` is part of the Hugging Face handler contract; the model
        # is always pulled from the Hub, so it is intentionally unused.
        self.load_model()

    def load_model(self):
        """Load the tokenizer and the int4-quantized model once, in eval mode."""
        model_name = "openbmb/MiniCPM-V-2_6-int4"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name, trust_remote_code=True, torch_dtype=torch.float16
        )
        self.model.eval()

    def _to_bytes(self, image_input):
        """Normalize a supported image payload to raw bytes.

        Accepts raw bytes/bytearray, a ``data:image/...;base64,`` URI, a bare
        base64 string, or an http(s) URL. JSON payloads can only carry
        strings, so the string forms are what the endpoint actually receives;
        `base64` and `requests` are already imported at the top of the file.

        Raises:
            ValueError: for unsupported payload types or undecodable base64
                (``binascii.Error`` is a ``ValueError`` subclass).
        """
        if isinstance(image_input, (bytes, bytearray)):
            return bytes(image_input)
        if isinstance(image_input, str):
            if image_input.startswith("data:image"):
                # Data URI: strip the "data:image/...;base64," prefix.
                return base64.b64decode(image_input.split(",", 1)[1])
            if image_input.startswith(("http://", "https://")):
                resp = requests.get(image_input, timeout=30)
                resp.raise_for_status()
                return resp.content
            # Otherwise assume a bare base64-encoded string.
            return base64.b64decode(image_input)
        raise ValueError(
            f"Unsupported image payload type: {type(image_input).__name__}"
        )

    def load_image(self, image_bytes):
        """Decode any supported payload (see `_to_bytes`) into an RGB PIL image.

        Raises:
            ValueError: if the payload cannot be fetched, decoded, or opened.
        """
        try:
            return Image.open(BytesIO(self._to_bytes(image_bytes))).convert("RGB")
        except ValueError:
            raise  # already a descriptive ValueError from _to_bytes
        except Exception as e:
            raise ValueError(f"Failed to open image from bytes: {e}")

    def predict(self, request):
        """Run one chat turn over the image.

        Returns ``{"output": str}`` on success; any failure is reported as
        ``{"error": str}`` so the serving loop never sees an exception.
        """
        image_input = request.get("image")
        question = request.get("question", "What is in the image?")
        stream = request.get("stream", False)

        try:
            image = self.load_image(image_input)
            # MiniCPM-V takes the image inline in the message content list.
            msgs = [{"role": "user", "content": [image, question]}]

            if stream:
                generated_text = ""
                res = self.model.chat(
                    image=None,
                    msgs=msgs,
                    tokenizer=self.tokenizer,
                    sampling=True,
                    temperature=0.7,
                    stream=True,
                )
                # Streaming is collapsed server-side; the client still gets
                # one complete response.
                for new_text in res:
                    generated_text += new_text
                return {"output": generated_text}
            else:
                output = self.model.chat(
                    image=None, msgs=msgs, tokenizer=self.tokenizer
                )
                return {"output": output}

        except Exception as e:
            # Top-level boundary: convert every failure into an error dict.
            return {"error": str(e)}

    def __call__(self, data):
        """Entry point invoked by the serving layer."""
        return self.predict(data)
59
 
60
+ # Example usage
61
+ handler = EndpointHandler()