VOIDER committed
Commit 2de5535 · verified · 1 Parent(s): b219c60

Update app.py

Files changed (1)
  1. app.py +18 -16
app.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
 from PIL import Image
 from sd_parsers import ParserManager
 from torchvision import transforms
-from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration
+from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig
 import lpips
 import piq
 import plotly.express as px
@@ -18,31 +18,33 @@ import plotly.express as px
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 # CLIP for prompt alignment & aesthetics
-clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+clip_model = CLIPModel.from_pretrained(
+    "openai/clip-vit-base-patch32"
+).to(device)
 clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
-# BLIP-2 for caption generation (processor without .to)
+# BLIP-2 for caption generation (8-bit quantized / fp8 proxy)
+bnb_config = BitsAndBytesConfig(load_in_8bit=True)
 blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
 blip_model = Blip2ForConditionalGeneration.from_pretrained(
-    "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
-).to(device)
+    "Salesforce/blip2-flan-t5-xl",
+    quantization_config=bnb_config,
+    device_map="auto"
+)
 
-# LPIPS for diversity\nlpips_model = lpips.LPIPS(net='alex').to(device)
+# LPIPS for diversity
+lpips_model = lpips.LPIPS(net='alex').to(device)
 
 # --------------------
 # Helper Functions
 # --------------------
 
-def extract_metadata(image_bytes):
-    """Extracts prompt and model name from image bytes using sd-parsers."""
+def extract_metadata(file):
+    """Extract prompt and model name using sd-parsers from file path."""
     parser = ParserManager()
-    tmp_path = "temp.png"
-    with open(tmp_path, 'wb') as tmp:
-        tmp.write(image_bytes)
-    info = parser.parse(tmp_path)
+    info = parser.parse(file.name)
     prompt = info.prompts[0].value if info.prompts else ''
    model_name = info.model_name or ''
-    os.remove(tmp_path)
     return prompt, model_name
 
 # Image preprocessing transform
@@ -90,9 +92,9 @@ def analyze_images(files):
     imgs_by_model = {}
 
     for f in files:
-        image_bytes = f.read()
-        img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
-        prompt, model = extract_metadata(image_bytes)
+        # use f.name path instead of read() to avoid NamedString issues
+        img = Image.open(f.name).convert('RGB')
+        prompt, model = extract_metadata(f)
 
         clip_score = compute_clip_score(img, prompt)
         cap_sim = compute_caption_similarity(img, prompt)
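
The quantized BLIP-2 load is the core of this commit. Below is a minimal standalone sketch of the same pattern, assuming bitsandbytes and accelerate are installed alongside transformers and a CUDA device is available; the sample image path and generation parameters are illustrative and not part of this commit.

from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig

# 8-bit weight quantization via bitsandbytes, as in the diff above
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-flan-t5-xl",
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place the quantized weights
)

image = Image.open("sample.png").convert("RGB")  # hypothetical input file
inputs = processor(images=image, return_tensors="pt").to(model.device)
generated = model.generate(**inputs, max_new_tokens=30)
caption = processor.batch_decode(generated, skip_special_tokens=True)[0].strip()
print(caption)

Note that with quantization_config set, no explicit .to(device) call is needed (or allowed) on the model; device_map="auto" handles placement, which is why the diff drops the old .to(device).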
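The loop change relies on Gradio passing each uploaded file as a temp-file wrapper (a NamedString) whose .name attribute is a path on disk, rather than a readable byte stream. A minimal sketch of that pattern follows, assuming a gr.Files input; the list_sizes callback is hypothetical and only illustrates the path handling.

import gradio as gr
from PIL import Image

def list_sizes(files):
    lines = []
    for f in files:
        # each item exposes .name, a temp-file path on disk; fall back to a
        # plain string path in case the component returns filepaths directly
        path = getattr(f, "name", f)
        img = Image.open(path).convert("RGB")
        lines.append(f"{path}: {img.width}x{img.height}")
    return "\n".join(lines)

demo = gr.Interface(fn=list_sizes, inputs=gr.Files(), outputs="text")
if __name__ == "__main__":
    demo.launch()

Opening by path instead of calling f.read() avoids the NamedString issue the diff comment mentions, and it also lets sd-parsers read the image's metadata directly from disk.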