Update app.py
Browse files
app.py
CHANGED
|
@@ -14,20 +14,20 @@ import plotly.express as px
|
|
| 14 |
# --------------------
|
| 15 |
# Setup Models
|
| 16 |
# --------------------
|
|
|
|
| 17 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 18 |
|
| 19 |
# CLIP for prompt alignment & aesthetics
|
| 20 |
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
|
| 21 |
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 22 |
|
| 23 |
-
# BLIP-2 for caption generation
|
| 24 |
-
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
| 25 |
blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
| 26 |
"Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
|
| 27 |
).to(device)
|
| 28 |
|
| 29 |
-
# LPIPS for diversity
|
| 30 |
-
lpips_model = lpips.LPIPS(net='alex').to(device)
|
| 31 |
|
| 32 |
# --------------------
|
| 33 |
# Helper Functions
|
|
@@ -36,20 +36,23 @@ lpips_model = lpips.LPIPS(net='alex').to(device)
|
|
| 36 |
def extract_metadata(image_bytes):
|
| 37 |
"""Extracts prompt and model name from image bytes using sd-parsers."""
|
| 38 |
parser = ParserManager()
|
| 39 |
-
|
|
|
|
| 40 |
tmp.write(image_bytes)
|
| 41 |
-
info = parser.parse(
|
| 42 |
prompt = info.prompts[0].value if info.prompts else ''
|
| 43 |
model_name = info.model_name or ''
|
| 44 |
-
os.remove(
|
| 45 |
return prompt, model_name
|
| 46 |
|
| 47 |
-
# Image preprocessing
|
| 48 |
preprocess = transforms.Compose([
|
| 49 |
transforms.Resize((224, 224)),
|
| 50 |
transforms.ToTensor(),
|
| 51 |
-
transforms.Normalize(
|
| 52 |
-
|
|
|
|
|
|
|
| 53 |
])
|
| 54 |
|
| 55 |
def compute_clip_score(img: Image.Image, text: str) -> float:
|
|
@@ -60,7 +63,7 @@ def compute_clip_score(img: Image.Image, text: str) -> float:
|
|
| 60 |
|
| 61 |
@torch.no_grad()
|
| 62 |
def compute_caption_similarity(img: Image.Image, prompt: str) -> float:
|
| 63 |
-
inputs = blip_processor(images=img, return_tensors="pt").to(device
|
| 64 |
out = blip_model.generate(**inputs)
|
| 65 |
caption = blip_processor.decode(out[0], skip_special_tokens=True)
|
| 66 |
return compute_clip_score(img, caption)
|
|
|
|
| 14 |
# --------------------
|
| 15 |
# Setup Models
|
| 16 |
# --------------------
|
| 17 |
+
|
| 18 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 19 |
|
| 20 |
# CLIP for prompt alignment & aesthetics
|
| 21 |
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
|
| 22 |
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 23 |
|
| 24 |
+
# BLIP-2 for caption generation (processor without .to)
|
| 25 |
+
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
| 26 |
blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
| 27 |
"Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
|
| 28 |
).to(device)
|
| 29 |
|
| 30 |
+
# LPIPS for diversity\ nlpips_model = lpips.LPIPS(net='alex').to(device)
|
|
|
|
| 31 |
|
| 32 |
# --------------------
|
| 33 |
# Helper Functions
|
|
|
|
| 36 |
def extract_metadata(image_bytes):
|
| 37 |
"""Extracts prompt and model name from image bytes using sd-parsers."""
|
| 38 |
parser = ParserManager()
|
| 39 |
+
tmp_path = "temp.png"
|
| 40 |
+
with open(tmp_path, 'wb') as tmp:
|
| 41 |
tmp.write(image_bytes)
|
| 42 |
+
info = parser.parse(tmp_path)
|
| 43 |
prompt = info.prompts[0].value if info.prompts else ''
|
| 44 |
model_name = info.model_name or ''
|
| 45 |
+
os.remove(tmp_path)
|
| 46 |
return prompt, model_name
|
| 47 |
|
| 48 |
+
# Image preprocessing transform
|
| 49 |
preprocess = transforms.Compose([
|
| 50 |
transforms.Resize((224, 224)),
|
| 51 |
transforms.ToTensor(),
|
| 52 |
+
transforms.Normalize(
|
| 53 |
+
(0.48145466, 0.4578275, 0.40821073),
|
| 54 |
+
(0.26862954, 0.26130258, 0.27577711)
|
| 55 |
+
)
|
| 56 |
])
|
| 57 |
|
| 58 |
def compute_clip_score(img: Image.Image, text: str) -> float:
|
|
|
|
| 63 |
|
| 64 |
@torch.no_grad()
|
| 65 |
def compute_caption_similarity(img: Image.Image, prompt: str) -> float:
|
| 66 |
+
inputs = blip_processor(images=img, return_tensors="pt").to(device)
|
| 67 |
out = blip_model.generate(**inputs)
|
| 68 |
caption = blip_processor.decode(out[0], skip_special_tokens=True)
|
| 69 |
return compute_clip_score(img, caption)
|