rishabh-zuma committed
Commit 7ab9f15 · Parent(s): 2ae7815

Moved to 16 bit precision

Files changed (1): handler.py (+2 −2)
handler.py CHANGED
@@ -22,7 +22,7 @@ class EndpointHandler():
 
         print(" $$$$ Model Loading $$$$")
         self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-        self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="auto")
+        self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
         print(" $$$$ model loaded $$$$")
         self.model.eval()
         self.model = self.model.to(device)
@@ -66,7 +66,7 @@ class EndpointHandler():
         # img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
         # raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
 
-        inputs = self.processor(raw_image, prompt, return_tensors="pt").to("cuda")
+        inputs = self.processor(raw_image, prompt, return_tensors="pt").to("cuda", torch.float16)
 
         out = self.model.generate(**inputs)
         captions = processor.decode(out[0], skip_special_tokens=True)
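
For reference, a minimal standalone sketch of the float16 pattern this commit adopts, assuming a single CUDA GPU. The demo image URL comes from the commented-out lines in handler.py; the prompt string here is only a placeholder.

import requests
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Load the BLIP-2 checkpoint in half precision, as in this commit:
# torch_dtype=torch.float16 roughly halves GPU memory versus the fp32 default.
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    torch_dtype=torch.float16,
    device_map="auto",
)
model.eval()

# Demo image from the commented-out lines in handler.py; placeholder prompt.
img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
prompt = "Question: what is shown in the image? Answer:"

# Cast the inputs to float16 as well so the pixel values match the model dtype.
inputs = processor(raw_image, prompt, return_tensors="pt").to("cuda", torch.float16)

with torch.no_grad():
    out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))

The second hunk's cast matters because a model loaded with torch_dtype=torch.float16 expects half-precision pixel values; leaving the inputs in fp32 can trigger a dtype-mismatch error during the forward pass.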