VOIDER committed
Commit ec46182 · verified · parent: 951a7f4

Update app.py

Files changed (1): app.py (+15 −9)
app.py CHANGED
@@ -23,14 +23,22 @@ clip_model = CLIPModel.from_pretrained(
 ).to(device)
 clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
-# BLIP-2 for caption generation (8-bit quantized / fp8 proxy)
-bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+# BLIP-2 for caption generation: 8-bit if GPU available, else float16
 blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
-blip_model = Blip2ForConditionalGeneration.from_pretrained(
-    "Salesforce/blip2-flan-t5-xl",
-    quantization_config=bnb_config,
-    device_map="auto"
-)
+if torch.cuda.is_available():
+    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+    blip_model = Blip2ForConditionalGeneration.from_pretrained(
+        "Salesforce/blip2-flan-t5-xl",
+        quantization_config=bnb_config,
+        device_map="auto"
+    )
+else:
+    # CPU-only environment: load half precision
+    blip_model = Blip2ForConditionalGeneration.from_pretrained(
+        "Salesforce/blip2-flan-t5-xl",
+        torch_dtype=torch.float16
+    )
+    blip_model.to(device)
 
 # LPIPS for diversity
 lpips_model = lpips.LPIPS(net='alex').to(device)
@@ -92,7 +100,6 @@ def analyze_images(files):
     imgs_by_model = {}
 
     for f in files:
-        # use f.name path instead of read() to avoid NamedString issues
         img = Image.open(f.name).convert('RGB')
         prompt, model = extract_metadata(f)
 
@@ -114,7 +121,6 @@ def analyze_images(files):
 
     df = pd.DataFrame(records)
 
-    # Diversity per model
    diversity = {}
     for model, imgs in imgs_by_model.items():
         if len(imgs) < 2:
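
Context for the branch added above: bitsandbytes 8-bit quantization only runs on CUDA devices, which is why the commit gates load_in_8bit on torch.cuda.is_available() and falls back to a plain float16 load otherwise. As a rough sketch of how the loaded pair is typically driven for captioning, the helper below is hypothetical (it is not part of app.py) and assumes blip_processor, blip_model, and device exist as in the diff:

# Hypothetical helper, not part of app.py: caption one image with the
# BLIP-2 pair loaded above (blip_processor / blip_model assumed in scope).
import torch
from PIL import Image

def caption_image(path: str) -> str:
    img = Image.open(path).convert("RGB")
    # Image-only input: the processor returns just pixel_values here.
    pixel_values = blip_processor(images=img, return_tensors="pt").pixel_values
    # Match the model's device and dtype (fp16 in both branches of the diff;
    # the 8-bit path keeps its non-quantized layers in fp16).
    pixel_values = pixel_values.to(blip_model.device, blip_model.dtype)
    with torch.no_grad():
        ids = blip_model.generate(pixel_values=pixel_values, max_new_tokens=30)
    return blip_processor.batch_decode(ids, skip_special_tokens=True)[0].strip()

One caveat the else-branch inherits: float16 on CPU mainly saves memory, and some CPU kernels lack fp16 implementations, so a float32 load is often the safer default when no GPU is present.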