sreejith8100 commited on
Commit
856dd67
·
verified ·
1 Parent(s): b6f80b8

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +52 -22
handler.py CHANGED
@@ -6,26 +6,40 @@ from io import BytesIO
6
  import base64
7
  import ssl
8
  import urllib3
9
-
10
  urllib3.disable_warnings()
11
  ssl._create_default_https_context = ssl._create_unverified_context
12
 
13
- class EndpointHandler:
14
- def __init__(self, path=""):
15
- model_name = "openbmb/MiniCPM-V-2_6-int4"
 
 
 
 
16
  self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
17
  self.model = AutoModel.from_pretrained(
18
  model_name,
19
  trust_remote_code=True,
20
- # Explicitly disable 4-bit loading
21
- device_map="auto"
22
- ).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- def __call__(self, data):
25
- image_input = data.get("image")
26
- question = data.get("question", "What is in this image?")
27
  if not image_input:
28
- return {"error": "Image is required."}
29
 
30
  try:
31
  if image_input.startswith("http"):
@@ -36,21 +50,37 @@ class EndpointHandler:
36
  except Exception as e:
37
  return {"error": f"Failed to load image: {e}"}
38
 
39
- msgs = [{"role": "user", "content": question}]
40
- result_text = ""
41
 
42
  try:
43
- with torch.no_grad():
44
- for chunk in self.model.chat(
45
- image=image,
 
46
  msgs=msgs,
47
  tokenizer=self.tokenizer,
48
- stream=True,
49
- max_new_tokens=128,
50
- temperature=0.3
51
  ):
52
- result_text += chunk
 
 
 
 
 
 
 
 
53
  except Exception as e:
54
- return {"error": f"Model inference failed: {e}"}
 
55
 
56
- return {"output": result_text}
 
 
 
 
 
 
 
 
 
6
  import base64
7
  import ssl
8
  import urllib3
 
9
  urllib3.disable_warnings()
10
  ssl._create_default_https_context = ssl._create_unverified_context
11
 
12
class ModelHandler:
    """Handler for MiniCPM-V-2.6 visual question answering.

    Construction is cheap: the heavyweight model/tokenizer objects are
    only created when load_model() is called.
    """

    def __init__(self):
        # Both are populated by load_model(); they stay None until then.
        self.tokenizer = None
        self.model = None
16
+
17
+ def load_model(self):
18
+ model_name = "openbmb/MiniCPM-V-2_6"
19
  self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
20
+
21
  self.model = AutoModel.from_pretrained(
22
  model_name,
23
  trust_remote_code=True,
24
+ attn_implementation="sdpa",
25
+ torch_dtype=torch.bfloat16
26
+ ).eval().cuda()
27
+
28
+ def predict(self, request):
29
+ """
30
+ Expected request format:
31
+ {
32
+ "image": "<url or base64 string>",
33
+ "question": "What is shown in the image?",
34
+ "stream": false (optional)
35
+ }
36
+ """
37
+ image_input = request.get("image")
38
+ question = request.get("question", "What is in the image?")
39
+ stream = request.get("stream", False)
40
 
 
 
 
41
  if not image_input:
42
+ return {"error": "Image input is required."}
43
 
44
  try:
45
  if image_input.startswith("http"):
 
50
  except Exception as e:
51
  return {"error": f"Failed to load image: {e}"}
52
 
53
+ msgs = [{"role": "user", "content": [image, question]}]
 
54
 
55
  try:
56
+ if stream:
57
+ generated_text = ""
58
+ for new_text in self.model.chat(
59
+ image=None,
60
  msgs=msgs,
61
  tokenizer=self.tokenizer,
62
+ sampling=True,
63
+ stream=True
 
64
  ):
65
+ generated_text += new_text
66
+ return {"output": generated_text}
67
+ else:
68
+ output = self.model.chat(
69
+ image=None,
70
+ msgs=msgs,
71
+ tokenizer=self.tokenizer
72
+ )
73
+ return {"output": output}
74
  except Exception as e:
75
+ return {"error": f"Inference failed: {e}"}
76
+
77
 
78
# Test block (optional, remove in production)
if __name__ == "__main__":
    # Smoke-test the handler end to end with a public sample image.
    demo_request = {
        "image": "https://upload.wikimedia.org/wikipedia/commons/9/9e/Ours_brun_parcanimalierpyrenees_1.jpg",
        "question": "What animal is this?",
    }
    handler = ModelHandler()
    handler.load_model()
    result = handler.predict(demo_request)
    print(result)