VOIDER committed on
Commit
b219c60
·
verified ·
1 Parent(s): 57f8ee6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -14,20 +14,20 @@ import plotly.express as px
14
  # --------------------
15
  # Setup Models
16
  # --------------------
 
17
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
 
19
  # CLIP for prompt alignment & aesthetics
20
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
21
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
22
 
23
- # BLIP-2 for caption generation
24
- blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)
25
  blip_model = Blip2ForConditionalGeneration.from_pretrained(
26
  "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
27
  ).to(device)
28
 
29
- # LPIPS for diversity
30
- lpips_model = lpips.LPIPS(net='alex').to(device)
31
 
32
  # --------------------
33
  # Helper Functions
@@ -36,20 +36,23 @@ lpips_model = lpips.LPIPS(net='alex').to(device)
36
  def extract_metadata(image_bytes):
37
  """Extracts prompt and model name from image bytes using sd-parsers."""
38
  parser = ParserManager()
39
- with open('temp.png', 'wb') as tmp:
 
40
  tmp.write(image_bytes)
41
- info = parser.parse('temp.png')
42
  prompt = info.prompts[0].value if info.prompts else ''
43
  model_name = info.model_name or ''
44
- os.remove('temp.png')
45
  return prompt, model_name
46
 
47
- # Image preprocessing for models
48
  preprocess = transforms.Compose([
49
  transforms.Resize((224, 224)),
50
  transforms.ToTensor(),
51
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
52
- (0.26862954, 0.26130258, 0.27577711))
 
 
53
  ])
54
 
55
  def compute_clip_score(img: Image.Image, text: str) -> float:
@@ -60,7 +63,7 @@ def compute_clip_score(img: Image.Image, text: str) -> float:
60
 
61
  @torch.no_grad()
62
  def compute_caption_similarity(img: Image.Image, prompt: str) -> float:
63
- inputs = blip_processor(images=img, return_tensors="pt").to(device, torch.float16)
64
  out = blip_model.generate(**inputs)
65
  caption = blip_processor.decode(out[0], skip_special_tokens=True)
66
  return compute_clip_score(img, caption)
 
14
# --------------------
# Setup Models
# --------------------

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CLIP for prompt alignment & aesthetics
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# BLIP-2 for caption generation. Note: the processor is a plain Python
# object, not an nn.Module, so it must NOT be moved with .to(device);
# only the model itself is placed on the device, in float16.
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
).to(device)

# LPIPS for diversity
lpips_model = lpips.LPIPS(net='alex').to(device)
 
31
 
32
  # --------------------
33
  # Helper Functions
 
36
def extract_metadata(image_bytes):
    """Extract the generation prompt and model name from PNG bytes via sd-parsers.

    sd-parsers' ParserManager reads from a file path, so the bytes are first
    written to a temporary file on disk.

    Args:
        image_bytes: Raw image file contents (bytes).

    Returns:
        Tuple of (prompt, model_name); either is '' when the metadata is absent.
    """
    parser = ParserManager()
    tmp_path = "temp.png"
    with open(tmp_path, 'wb') as tmp:
        tmp.write(image_bytes)
    try:
        info = parser.parse(tmp_path)
        prompt = info.prompts[0].value if info.prompts else ''
        model_name = info.model_name or ''
    finally:
        # Remove the temp file even when parsing raises, so it never leaks.
        os.remove(tmp_path)
    return prompt, model_name
47
 
48
# Image preprocessing pipeline: resize to the 224x224 input size, convert to
# a tensor, and normalize per channel (these mean/std tuples match the
# constants used by OpenAI CLIP preprocessing).
_NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
_NORM_STD = (0.26862954, 0.26130258, 0.27577711)

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
57
 
58
  def compute_clip_score(img: Image.Image, text: str) -> float:
 
63
 
64
@torch.no_grad()
def compute_caption_similarity(img: Image.Image, prompt: str) -> float:
    """Caption the image with BLIP-2, then CLIP-score the caption against the image.

    Args:
        img: Input PIL image.
        prompt: Original generation prompt. Currently unused in the body;
            kept for interface compatibility with callers.

    Returns:
        CLIP similarity score between the image and its generated caption.
    """
    # Cast pixel values to float16 to match blip_model, which is loaded with
    # torch_dtype=torch.float16; float32 inputs would trigger a dtype
    # mismatch inside the model's layers.
    inputs = blip_processor(images=img, return_tensors="pt").to(device, torch.float16)
    out = blip_model.generate(**inputs)
    caption = blip_processor.decode(out[0], skip_special_tokens=True)
    return compute_clip_score(img, caption)