rishabh-zuma commited on
Commit
db1d42e
·
1 Parent(s): 98f4fce

rollback to previous code

Browse files
Files changed (1) hide show
  1. handler.py +9 -8
handler.py CHANGED
@@ -21,9 +21,10 @@ class EndpointHandler():
21
 
22
 
23
  print(" $$$$ Model Loading $$$$")
24
- self.processor = Blip2Processor.from_pretrained("blip2/sharded")
25
- self.model = Blip2ForConditionalGeneration.from_pretrained("blip2/sharded", device_map = "auto", load_in_8bit = True)
26
  print(" $$$$ model loaded $$$$")
 
27
 
28
 
29
 
@@ -63,13 +64,13 @@ class EndpointHandler():
63
  # img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
64
  # raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
65
 
66
- generated_ids = self.processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
67
- generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
68
  print("@@@@@@ generated_text @@@@@@@")
69
  print(generated_text)
 
 
70
 
 
71
 
72
- # out = self.model.generate(**inputs)
73
- # captions = processor.decode(out[0], skip_special_tokens=True)
74
-
75
- return {"captions": generated_text}
 
21
 
22
 
23
  print(" $$$$ Model Loading $$$$")
24
+ self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
25
+ self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
26
  print(" $$$$ model loaded $$$$")
27
+ print(self.model.eval())
28
 
29
 
30
 
 
64
  # img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
65
  # raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
66
 
67
+ inputs = self.processor(raw_image, prompt, return_tensors="pt").to("cuda", torch.float16)
68
+
69
  print("@@@@@@ generated_text @@@@@@@")
70
  print(generated_text)
71
+ out = self.model.generate(**inputs)
72
+ captions = processor.decode(out[0], skip_special_tokens=True)
73
 
74
+ print("captions", captions)
75
 
76
+ return {"captions": captions}