Update app.py
Browse files
app.py
CHANGED
|
@@ -23,14 +23,22 @@ clip_model = CLIPModel.from_pretrained(
|
|
| 23 |
).to(device)
|
| 24 |
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 25 |
|
| 26 |
-
# BLIP-2 for caption generation
|
| 27 |
-
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 28 |
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# LPIPS for diversity
|
| 36 |
lpips_model = lpips.LPIPS(net='alex').to(device)
|
|
@@ -92,7 +100,6 @@ def analyze_images(files):
|
|
| 92 |
imgs_by_model = {}
|
| 93 |
|
| 94 |
for f in files:
|
| 95 |
-
# use f.name path instead of read() to avoid NamedString issues
|
| 96 |
img = Image.open(f.name).convert('RGB')
|
| 97 |
prompt, model = extract_metadata(f)
|
| 98 |
|
|
@@ -114,7 +121,6 @@ def analyze_images(files):
|
|
| 114 |
|
| 115 |
df = pd.DataFrame(records)
|
| 116 |
|
| 117 |
-
# Diversity per model
|
| 118 |
diversity = {}
|
| 119 |
for model, imgs in imgs_by_model.items():
|
| 120 |
if len(imgs) < 2:
|
|
|
|
| 23 |
).to(device)
|
| 24 |
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 25 |
|
| 26 |
+
# BLIP-2 for caption generation: 8-bit if GPU available, else float16
|
|
|
|
| 27 |
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
| 28 |
+
if torch.cuda.is_available():
|
| 29 |
+
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 30 |
+
blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
| 31 |
+
"Salesforce/blip2-flan-t5-xl",
|
| 32 |
+
quantization_config=bnb_config,
|
| 33 |
+
device_map="auto"
|
| 34 |
+
)
|
| 35 |
+
else:
|
| 36 |
+
# CPU-only environment: load half precision
|
| 37 |
+
blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
| 38 |
+
"Salesforce/blip2-flan-t5-xl",
|
| 39 |
+
torch_dtype=torch.float16
|
| 40 |
+
)
|
| 41 |
+
blip_model.to(device)
|
| 42 |
|
| 43 |
# LPIPS for diversity
|
| 44 |
lpips_model = lpips.LPIPS(net='alex').to(device)
|
|
|
|
| 100 |
imgs_by_model = {}
|
| 101 |
|
| 102 |
for f in files:
|
|
|
|
| 103 |
img = Image.open(f.name).convert('RGB')
|
| 104 |
prompt, model = extract_metadata(f)
|
| 105 |
|
|
|
|
| 121 |
|
| 122 |
df = pd.DataFrame(records)
|
| 123 |
|
|
|
|
| 124 |
diversity = {}
|
| 125 |
for model, imgs in imgs_by_model.items():
|
| 126 |
if len(imgs) < 2:
|