Update app.py
Browse files
app.py
CHANGED
|
@@ -74,15 +74,15 @@ import torch
|
|
| 74 |
# ----------------------
|
| 75 |
# Load BLIP2 for Captioning
|
| 76 |
# ----------------------
|
| 77 |
-
caption_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-
|
| 78 |
-
caption_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-
|
| 79 |
|
| 80 |
# ----------------------
|
| 81 |
# Load BLIP2 for VQA
|
| 82 |
# ----------------------
|
| 83 |
-
vqa_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-
|
| 84 |
vqa_model = Blip2ForConditionalGeneration.from_pretrained(
|
| 85 |
-
|
| 86 |
)
|
| 87 |
|
| 88 |
# ----------------------
|
|
|
|
| 74 |
# ----------------------
|
| 75 |
# Load BLIP2 for Captioning
|
| 76 |
# ----------------------
|
| 77 |
+
caption_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-125m")
|
| 78 |
+
caption_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-125m")
|
| 79 |
|
| 80 |
# ----------------------
|
| 81 |
# Load BLIP2 for VQA
|
| 82 |
# ----------------------
|
| 83 |
+
vqa_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-base")
|
| 84 |
vqa_model = Blip2ForConditionalGeneration.from_pretrained(
|
| 85 |
+
"Salesforce/blip2-flan-t5-base", torch_dtype=torch.float16, device_map="auto"
|
| 86 |
)
|
| 87 |
|
| 88 |
# ----------------------
|