Update app.py
Browse files
app.py
CHANGED
|
@@ -18,12 +18,10 @@ import plotly.express as px
|
|
| 18 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 19 |
|
| 20 |
# CLIP for prompt alignment & aesthetics
|
| 21 |
-
clip_model = CLIPModel.from_pretrained(
|
| 22 |
-
"openai/clip-vit-base-patch32"
|
| 23 |
-
).to(device)
|
| 24 |
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 25 |
|
| 26 |
-
# BLIP-2 for caption generation: 8-bit if GPU available, else
|
| 27 |
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
| 28 |
if torch.cuda.is_available():
|
| 29 |
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
|
|
@@ -33,12 +31,10 @@ if torch.cuda.is_available():
|
|
| 33 |
device_map="auto"
|
| 34 |
)
|
| 35 |
else:
|
| 36 |
-
# CPU-only environment: load half precision
|
| 37 |
blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
| 38 |
"Salesforce/blip2-flan-t5-xl",
|
| 39 |
torch_dtype=torch.float16
|
| 40 |
-
)
|
| 41 |
-
blip_model.to(device)
|
| 42 |
|
| 43 |
# LPIPS for diversity
|
| 44 |
lpips_model = lpips.LPIPS(net='alex').to(device)
|
|
@@ -51,11 +47,8 @@ def extract_metadata(file):
|
|
| 51 |
"""Extract prompt and model name using sd-parsers from file path."""
|
| 52 |
parser = ParserManager()
|
| 53 |
info = parser.parse(file.name)
|
| 54 |
-
# prompts list
|
| 55 |
prompt = info.prompts[0].value if info.prompts else ''
|
| 56 |
-
# models list may contain model identifiers
|
| 57 |
if hasattr(info, 'models') and info.models:
|
| 58 |
-
# info.models may be list of strings or objects
|
| 59 |
first = info.models[0]
|
| 60 |
model_name = first.name if hasattr(first, 'name') else str(first)
|
| 61 |
else:
|
|
@@ -63,7 +56,7 @@ def extract_metadata(file):
|
|
| 63 |
return prompt, model_name
|
| 64 |
|
| 65 |
# Image preprocessing transform
|
| 66 |
-
preprocess = transforms.Compose([
|
| 67 |
transforms.Resize((224, 224)),
|
| 68 |
transforms.ToTensor(),
|
| 69 |
transforms.Normalize(
|
|
@@ -72,6 +65,10 @@ preprocess = transforms.Compose([ transforms.Compose([
|
|
| 72 |
)
|
| 73 |
])
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
def compute_clip_score(img: Image.Image, text: str) -> float:
|
| 76 |
inputs = clip_processor(text=[text], images=img, return_tensors="pt", padding=True).to(device)
|
| 77 |
outputs = clip_model(**inputs)
|
|
@@ -110,7 +107,7 @@ def analyze_images(files):
|
|
| 110 |
img = Image.open(f.name).convert('RGB')
|
| 111 |
prompt, model = extract_metadata(f)
|
| 112 |
|
| 113 |
-
|
| 114 |
cap_sim = compute_caption_similarity(img, prompt)
|
| 115 |
brisque, niqe = compute_iqa_metrics(img)
|
| 116 |
aesthetic = compute_clip_score(img, "a beautiful high quality image")
|
|
@@ -118,7 +115,7 @@ def analyze_images(files):
|
|
| 118 |
records.append({
|
| 119 |
'model': model,
|
| 120 |
'prompt': prompt,
|
| 121 |
-
'clip_score':
|
| 122 |
'caption_sim': cap_sim,
|
| 123 |
'brisque': brisque,
|
| 124 |
'niqe': niqe,
|
|
@@ -153,14 +150,13 @@ def analyze_images(files):
|
|
| 153 |
# --------------------
|
| 154 |
|
| 155 |
def plot_metrics(agg: pd.DataFrame):
|
| 156 |
-
|
| 157 |
agg,
|
| 158 |
x='model',
|
| 159 |
y=['aesthetic_mean', 'clip_score_mean', 'caption_sim_mean', 'diversity'],
|
| 160 |
barmode='group',
|
| 161 |
title='Сравнение моделей по метрикам'
|
| 162 |
)
|
| 163 |
-
return fig
|
| 164 |
|
| 165 |
# --------------------
|
| 166 |
# Gradio Interface
|
|
@@ -178,7 +174,10 @@ with gr.Blocks() as demo:
|
|
| 178 |
with gr.Row():
|
| 179 |
input_files = gr.File(file_count="multiple", label="Выберите PNG файлы")
|
| 180 |
output_table = gr.Dataframe(
|
| 181 |
-
headers=[
|
|
|
|
|
|
|
|
|
|
| 182 |
label="Сводная таблица"
|
| 183 |
)
|
| 184 |
|
|
|
|
| 18 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 19 |
|
| 20 |
# CLIP for prompt alignment & aesthetics
|
| 21 |
+
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
|
|
|
|
|
|
|
| 22 |
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 23 |
|
| 24 |
+
# BLIP-2 for caption generation: 8-bit if GPU available, else half precision
|
| 25 |
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
| 26 |
if torch.cuda.is_available():
|
| 27 |
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
|
|
|
|
| 31 |
device_map="auto"
|
| 32 |
)
|
| 33 |
else:
|
|
|
|
| 34 |
blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
| 35 |
"Salesforce/blip2-flan-t5-xl",
|
| 36 |
torch_dtype=torch.float16
|
| 37 |
+
).to(device)
|
|
|
|
| 38 |
|
| 39 |
# LPIPS for diversity
|
| 40 |
lpips_model = lpips.LPIPS(net='alex').to(device)
|
|
|
|
| 47 |
"""Extract prompt and model name using sd-parsers from file path."""
|
| 48 |
parser = ParserManager()
|
| 49 |
info = parser.parse(file.name)
|
|
|
|
| 50 |
prompt = info.prompts[0].value if info.prompts else ''
|
|
|
|
| 51 |
if hasattr(info, 'models') and info.models:
|
|
|
|
| 52 |
first = info.models[0]
|
| 53 |
model_name = first.name if hasattr(first, 'name') else str(first)
|
| 54 |
else:
|
|
|
|
| 56 |
return prompt, model_name
|
| 57 |
|
| 58 |
# Image preprocessing transform
|
| 59 |
+
preprocess = transforms.Compose([
|
| 60 |
transforms.Resize((224, 224)),
|
| 61 |
transforms.ToTensor(),
|
| 62 |
transforms.Normalize(
|
|
|
|
| 65 |
)
|
| 66 |
])
|
| 67 |
|
| 68 |
+
# --------------------
|
| 69 |
+
# Metric Computations
|
| 70 |
+
# --------------------
|
| 71 |
+
|
| 72 |
def compute_clip_score(img: Image.Image, text: str) -> float:
|
| 73 |
inputs = clip_processor(text=[text], images=img, return_tensors="pt", padding=True).to(device)
|
| 74 |
outputs = clip_model(**inputs)
|
|
|
|
| 107 |
img = Image.open(f.name).convert('RGB')
|
| 108 |
prompt, model = extract_metadata(f)
|
| 109 |
|
| 110 |
+
cs = compute_clip_score(img, prompt)
|
| 111 |
cap_sim = compute_caption_similarity(img, prompt)
|
| 112 |
brisque, niqe = compute_iqa_metrics(img)
|
| 113 |
aesthetic = compute_clip_score(img, "a beautiful high quality image")
|
|
|
|
| 115 |
records.append({
|
| 116 |
'model': model,
|
| 117 |
'prompt': prompt,
|
| 118 |
+
'clip_score': cs,
|
| 119 |
'caption_sim': cap_sim,
|
| 120 |
'brisque': brisque,
|
| 121 |
'niqe': niqe,
|
|
|
|
| 150 |
# --------------------
|
| 151 |
|
| 152 |
def plot_metrics(agg: pd.DataFrame):
    """Build a grouped bar chart comparing models across the aggregated metrics.

    `agg` is expected to have one row per model with the aggregate columns
    listed below — presumably produced by analyze_images (TODO confirm against
    caller, which is outside this view).
    """
    # Columns plotted side by side for each model on the x axis.
    metric_columns = [
        'aesthetic_mean',
        'clip_score_mean',
        'caption_sim_mean',
        'diversity',
    ]
    fig = px.bar(
        agg,
        x='model',
        y=metric_columns,
        barmode='group',
        title='Сравнение моделей по метрикам',
    )
    return fig
|
|
|
|
| 160 |
|
| 161 |
# --------------------
|
| 162 |
# Gradio Interface
|
|
|
|
| 174 |
with gr.Row():
|
| 175 |
input_files = gr.File(file_count="multiple", label="Выберите PNG файлы")
|
| 176 |
output_table = gr.Dataframe(
|
| 177 |
+
headers=[
|
| 178 |
+
"model", "clip_score_mean", "caption_sim_mean", "brisque_mean",
|
| 179 |
+
"niqe_mean", "aesthetic_mean", "diversity"
|
| 180 |
+
],
|
| 181 |
label="Сводная таблица"
|
| 182 |
)
|
| 183 |
|