wealthcoders committed on
Commit
16ceee2
·
verified ·
1 Parent(s): 248decf

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -9
handler.py CHANGED
@@ -3,26 +3,27 @@ from typing import Dict, List, Any
3
  import torch
4
 
5
class EndpointHandler:
    """Inference endpoint handler for Qwen3-VL-8B-Instruct.

    The processor and model are cached at class level so the multi-GB
    checkpoint is loaded once per process — not rebuilt on every request,
    which was the original defect.
    """

    # Lazily-populated caches; filled on the first request.
    _model = None      # Qwen3VLForConditionalGeneration instance
    _processor = None  # AutoProcessor instance

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a chat completion for the request payload.

        Args:
            data: request body; expected to carry a "messages" key holding
                a chat-template conversation (image and/or text turns).

        Returns:
            The decoded text of the first generated sequence.

        Raises:
            ValueError: if the payload has no "messages" key.
        """
        cls = type(self)
        if cls._model is None:
            # First request only: load once. device_map="auto" spreads the
            # weights over whatever GPUs are available.
            cls._model = Qwen3VLForConditionalGeneration.from_pretrained(
                "Qwen/Qwen3-VL-8B-Instruct",
                device_map="auto",
            )
            cls._processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")

        messages = data.get("messages")
        if messages is None:
            raise ValueError('Request payload must contain a "messages" key.')

        # apply_chat_template's first parameter is the conversation itself,
        # so pass it positionally (the keyword name varies across
        # transformers versions).
        inputs = cls._processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs = inputs.to(cls._model.device)

        # inference_mode: no autograd bookkeeping during generation.
        with torch.inference_mode():
            generated_ids = cls._model.generate(**inputs, max_new_tokens=128)

        output_text = cls._processor.batch_decode(generated_ids, skip_special_tokens=True)
        return output_text[0]
 
3
  import torch
4
 
5
class EndpointHandler:
    """Inference endpoint handler for Qwen3-VL-8B-Instruct.

    Loads the processor and model once at construction time so each
    request pays only for generation, never for model loading.
    """

    def __init__(self, path: str = "Qwen/Qwen3-VL-8B-Instruct"):
        """Load the processor and model from *path* (hub id or local dir)."""
        self.processor = AutoProcessor.from_pretrained(path)
        # device_map="auto" spreads the weights over the available GPUs.
        self.model = Qwen3VLForConditionalGeneration.from_pretrained(path, device_map="auto")
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a chat completion for the request payload.

        Args:
            data: request body; expected to carry a "messages" key holding
                a chat-template conversation (image and/or text turns).

        Returns:
            The decoded text of the first generated sequence.
            NOTE(review): the declared return type does not match this —
            kept as-is to avoid touching the public signature.

        Raises:
            ValueError: if the payload has no "messages" key.
        """
        messages = data.get("messages")
        if messages is None:
            # Fail fast with a clear message instead of crashing inside
            # the processor with an opaque TypeError on None.
            raise ValueError('Request payload must contain a "messages" key.')

        # apply_chat_template's first parameter is the conversation itself,
        # so pass it positionally (the keyword name varies across
        # transformers versions).
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        # inference_mode: no autograd bookkeeping during generation.
        with torch.inference_mode():
            generated_ids = self.model.generate(**inputs, max_new_tokens=128)

        output_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        return output_text[0]