VOIDER committed on
Commit
b1862d3
·
verified ·
1 Parent(s): 7a83d0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -16
app.py CHANGED
@@ -18,12 +18,10 @@ import plotly.express as px
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
 
20
  # CLIP for prompt alignment & aesthetics
21
- clip_model = CLIPModel.from_pretrained(
22
- "openai/clip-vit-base-patch32"
23
- ).to(device)
24
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
25
 
26
- # BLIP-2 for caption generation: 8-bit if GPU available, else float16
27
  blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
28
  if torch.cuda.is_available():
29
  bnb_config = BitsAndBytesConfig(load_in_8bit=True)
@@ -33,12 +31,10 @@ if torch.cuda.is_available():
33
  device_map="auto"
34
  )
35
  else:
36
- # CPU-only environment: load half precision
37
  blip_model = Blip2ForConditionalGeneration.from_pretrained(
38
  "Salesforce/blip2-flan-t5-xl",
39
  torch_dtype=torch.float16
40
- )
41
- blip_model.to(device)
42
 
43
  # LPIPS for diversity
44
  lpips_model = lpips.LPIPS(net='alex').to(device)
@@ -51,11 +47,8 @@ def extract_metadata(file):
51
  """Extract prompt and model name using sd-parsers from file path."""
52
  parser = ParserManager()
53
  info = parser.parse(file.name)
54
- # prompts list
55
  prompt = info.prompts[0].value if info.prompts else ''
56
- # models list may contain model identifiers
57
  if hasattr(info, 'models') and info.models:
58
- # info.models may be list of strings or objects
59
  first = info.models[0]
60
  model_name = first.name if hasattr(first, 'name') else str(first)
61
  else:
@@ -63,7 +56,7 @@ def extract_metadata(file):
63
  return prompt, model_name
64
 
65
  # Image preprocessing transform
66
- preprocess = transforms.Compose([ transforms.Compose([
67
  transforms.Resize((224, 224)),
68
  transforms.ToTensor(),
69
  transforms.Normalize(
@@ -72,6 +65,10 @@ preprocess = transforms.Compose([ transforms.Compose([
72
  )
73
  ])
74
 
 
 
 
 
75
  def compute_clip_score(img: Image.Image, text: str) -> float:
76
  inputs = clip_processor(text=[text], images=img, return_tensors="pt", padding=True).to(device)
77
  outputs = clip_model(**inputs)
@@ -110,7 +107,7 @@ def analyze_images(files):
110
  img = Image.open(f.name).convert('RGB')
111
  prompt, model = extract_metadata(f)
112
 
113
- clip_score = compute_clip_score(img, prompt)
114
  cap_sim = compute_caption_similarity(img, prompt)
115
  brisque, niqe = compute_iqa_metrics(img)
116
  aesthetic = compute_clip_score(img, "a beautiful high quality image")
@@ -118,7 +115,7 @@ def analyze_images(files):
118
  records.append({
119
  'model': model,
120
  'prompt': prompt,
121
- 'clip_score': clip_score,
122
  'caption_sim': cap_sim,
123
  'brisque': brisque,
124
  'niqe': niqe,
@@ -153,14 +150,13 @@ def analyze_images(files):
153
  # --------------------
154
 
155
  def plot_metrics(agg: pd.DataFrame):
156
- fig = px.bar(
157
  agg,
158
  x='model',
159
  y=['aesthetic_mean', 'clip_score_mean', 'caption_sim_mean', 'diversity'],
160
  barmode='group',
161
  title='Сравнение моделей по метрикам'
162
  )
163
- return fig
164
 
165
  # --------------------
166
  # Gradio Interface
@@ -178,7 +174,10 @@ with gr.Blocks() as demo:
178
  with gr.Row():
179
  input_files = gr.File(file_count="multiple", label="Выберите PNG файлы")
180
  output_table = gr.Dataframe(
181
- headers=["model", "clip_score_mean", "caption_sim_mean", "brisque_mean", "niqe_mean", "aesthetic_mean", "diversity"],
 
 
 
182
  label="Сводная таблица"
183
  )
184
 
 
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
 
20
  # CLIP for prompt alignment & aesthetics
21
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
 
 
22
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
23
 
24
+ # BLIP-2 for caption generation: 8-bit if GPU available, else half precision
25
  blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
26
  if torch.cuda.is_available():
27
  bnb_config = BitsAndBytesConfig(load_in_8bit=True)
 
31
  device_map="auto"
32
  )
33
  else:
 
34
  blip_model = Blip2ForConditionalGeneration.from_pretrained(
35
  "Salesforce/blip2-flan-t5-xl",
36
  torch_dtype=torch.float16
37
+ ).to(device)
 
38
 
39
  # LPIPS for diversity
40
  lpips_model = lpips.LPIPS(net='alex').to(device)
 
47
  """Extract prompt and model name using sd-parsers from file path."""
48
  parser = ParserManager()
49
  info = parser.parse(file.name)
 
50
  prompt = info.prompts[0].value if info.prompts else ''
 
51
  if hasattr(info, 'models') and info.models:
 
52
  first = info.models[0]
53
  model_name = first.name if hasattr(first, 'name') else str(first)
54
  else:
 
56
  return prompt, model_name
57
 
58
  # Image preprocessing transform
59
+ preprocess = transforms.Compose([
60
  transforms.Resize((224, 224)),
61
  transforms.ToTensor(),
62
  transforms.Normalize(
 
65
  )
66
  ])
67
 
68
+ # --------------------
69
+ # Metric Computations
70
+ # --------------------
71
+
72
  def compute_clip_score(img: Image.Image, text: str) -> float:
73
  inputs = clip_processor(text=[text], images=img, return_tensors="pt", padding=True).to(device)
74
  outputs = clip_model(**inputs)
 
107
  img = Image.open(f.name).convert('RGB')
108
  prompt, model = extract_metadata(f)
109
 
110
+ cs = compute_clip_score(img, prompt)
111
  cap_sim = compute_caption_similarity(img, prompt)
112
  brisque, niqe = compute_iqa_metrics(img)
113
  aesthetic = compute_clip_score(img, "a beautiful high quality image")
 
115
  records.append({
116
  'model': model,
117
  'prompt': prompt,
118
+ 'clip_score': cs,
119
  'caption_sim': cap_sim,
120
  'brisque': brisque,
121
  'niqe': niqe,
 
150
  # --------------------
151
 
152
  def plot_metrics(agg: pd.DataFrame):
153
+ return px.bar(
154
  agg,
155
  x='model',
156
  y=['aesthetic_mean', 'clip_score_mean', 'caption_sim_mean', 'diversity'],
157
  barmode='group',
158
  title='Сравнение моделей по метрикам'
159
  )
 
160
 
161
  # --------------------
162
  # Gradio Interface
 
174
  with gr.Row():
175
  input_files = gr.File(file_count="multiple", label="Выберите PNG файлы")
176
  output_table = gr.Dataframe(
177
+ headers=[
178
+ "model", "clip_score_mean", "caption_sim_mean", "brisque_mean",
179
+ "niqe_mean", "aesthetic_mean", "diversity"
180
+ ],
181
  label="Сводная таблица"
182
  )
183