Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ import gradio as gr
|
|
| 6 |
from PIL import Image
|
| 7 |
from sd_parsers import ParserManager
|
| 8 |
from torchvision import transforms
|
| 9 |
-
from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration
|
| 10 |
import lpips
|
| 11 |
import piq
|
| 12 |
import plotly.express as px
|
|
@@ -18,31 +18,33 @@ import plotly.express as px
|
|
| 18 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 19 |
|
| 20 |
# CLIP for prompt alignment & aesthetics
|
| 21 |
-
clip_model = CLIPModel.from_pretrained(
|
|
|
|
|
|
|
| 22 |
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 23 |
|
| 24 |
-
# BLIP-2 for caption generation (
|
|
|
|
| 25 |
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
| 26 |
blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
| 27 |
-
"Salesforce/blip2-flan-t5-xl",
|
| 28 |
-
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
# LPIPS for diversity
|
|
|
|
| 31 |
|
| 32 |
# --------------------
|
| 33 |
# Helper Functions
|
| 34 |
# --------------------
|
| 35 |
|
| 36 |
-
def extract_metadata(
|
| 37 |
-
"""
|
| 38 |
parser = ParserManager()
|
| 39 |
-
|
| 40 |
-
with open(tmp_path, 'wb') as tmp:
|
| 41 |
-
tmp.write(image_bytes)
|
| 42 |
-
info = parser.parse(tmp_path)
|
| 43 |
prompt = info.prompts[0].value if info.prompts else ''
|
| 44 |
model_name = info.model_name or ''
|
| 45 |
-
os.remove(tmp_path)
|
| 46 |
return prompt, model_name
|
| 47 |
|
| 48 |
# Image preprocessing transform
|
|
@@ -90,9 +92,9 @@ def analyze_images(files):
|
|
| 90 |
imgs_by_model = {}
|
| 91 |
|
| 92 |
for f in files:
|
| 93 |
-
|
| 94 |
-
img = Image.open(
|
| 95 |
-
prompt, model = extract_metadata(
|
| 96 |
|
| 97 |
clip_score = compute_clip_score(img, prompt)
|
| 98 |
cap_sim = compute_caption_similarity(img, prompt)
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
from sd_parsers import ParserManager
|
| 8 |
from torchvision import transforms
|
| 9 |
+
from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig
|
| 10 |
import lpips
|
| 11 |
import piq
|
| 12 |
import plotly.express as px
|
|
|
|
| 18 |
# -----------------------------------------------------------------
# Model initialization — executed once at import time; weight files
# are downloaded on first use by the transformers/lpips caches.
# -----------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CLIP for prompt alignment & aesthetics
_CLIP_ID = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(_CLIP_ID).to(device)
clip_processor = CLIPProcessor.from_pretrained(_CLIP_ID)

# BLIP-2 for caption generation, loaded with 8-bit (int8) weight
# quantization via bitsandbytes so the large checkpoint fits in
# limited GPU memory; device placement is delegated to accelerate.
_BLIP2_ID = "Salesforce/blip2-flan-t5-xl"
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
blip_processor = Blip2Processor.from_pretrained(_BLIP2_ID)
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    _BLIP2_ID,
    quantization_config=bnb_config,
    device_map="auto",
)

# LPIPS (AlexNet backbone) for pairwise image diversity
lpips_model = lpips.LPIPS(net='alex').to(device)
|
| 37 |
|
| 38 |
# --------------------
|
| 39 |
# Helper Functions
|
| 40 |
# --------------------
|
| 41 |
|
| 42 |
+
def extract_metadata(file):
    """Extract the generation prompt and model name from an image file.

    Parameters
    ----------
    file:
        An object exposing a ``name`` attribute that holds a filesystem
        path (e.g. a Gradio upload file object).

    Returns
    -------
    tuple[str, str]
        ``(prompt, model_name)``; both are empty strings when the image
        carries no recognizable generation metadata.
    """
    parser = ParserManager()
    info = parser.parse(file.name)
    # ParserManager.parse returns None for images without recognizable
    # SD generation metadata — guard to avoid AttributeError on plain
    # (non-AI-generated) uploads.
    if info is None:
        return '', ''
    prompt = info.prompts[0].value if info.prompts else ''
    model_name = info.model_name or ''
    return prompt, model_name
|
| 49 |
|
| 50 |
# Image preprocessing transform
|
|
|
|
| 92 |
imgs_by_model = {}
|
| 93 |
|
| 94 |
for f in files:
|
| 95 |
+
# use f.name path instead of read() to avoid NamedString issues
|
| 96 |
+
img = Image.open(f.name).convert('RGB')
|
| 97 |
+
prompt, model = extract_metadata(f)
|
| 98 |
|
| 99 |
clip_score = compute_clip_score(img, prompt)
|
| 100 |
cap_sim = compute_caption_similarity(img, prompt)
|