Spaces:
Runtime error
Runtime error
Create clip_superior
Browse files- clip_superior +99 -0
clip_superior
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""Gradio web UI for the CLIP Interrogator: generate text prompts from images."""
import argparse
import sys

import torch
from clip_interrogator import Config, Interrogator, list_caption_models, list_clip_models

try:
    import gradio as gr
except ImportError:
    print("Gradio is not installed, please install it with 'pip install gradio'")
    # Bare exit() comes from the site module and is absent under `python -m` /
    # embedded interpreters; sys.exit is always available.
    sys.exit(1)

parser = argparse.ArgumentParser()
parser.add_argument("--lowvram", action='store_true', help="Optimize settings for low VRAM")
parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
args = parser.parse_args()

if not torch.cuda.is_available():
    print("CUDA is not available, using CPU. Warning: this will be very slow!")

# Cache downloaded model weights under ./cache so restarts don't re-download.
config = Config(cache_path="cache")
if args.lowvram:
    config.apply_low_vram_defaults()
# Single module-level Interrogator shared by all UI callbacks below.
ci = Interrogator(config)
def image_analysis(image, clip_model_name):
    """Rank the best matches for *image* across the five label tables.

    Swaps in *clip_model_name* if it differs from the currently loaded CLIP
    model, then returns five ``{label: similarity}`` dicts — mediums, artists,
    movements, trendings, flavors — each holding the top five matches.
    """
    # Hot-swap the CLIP model only when the selection actually changed.
    if clip_model_name != ci.config.clip_model_name:
        ci.config.clip_model_name = clip_model_name
        ci.load_clip_model()

    rgb = image.convert('RGB')
    features = ci.image_to_features(rgb)

    # Top-5 labels per table, in the fixed output order the UI expects.
    tables = (ci.mediums, ci.artists, ci.movements, ci.trendings, ci.flavors)
    top_labels = [table.rank(features, 5) for table in tables]

    # Pair every label with its similarity score, one dict per category.
    return tuple(
        dict(zip(labels, ci.similarities(features, labels)))
        for labels in top_labels
    )
def image_to_prompt(image, mode, clip_model_name, blip_model_name):
    """Generate a prompt for *image* with the selected interrogation mode.

    Reloads the caption and/or CLIP model first when the requested names
    differ from the ones currently loaded. Unrecognized modes yield None.
    """
    # Hot-swap models only when the dropdown selections changed.
    if blip_model_name != ci.config.caption_model_name:
        ci.config.caption_model_name = blip_model_name
        ci.load_caption_model()

    if clip_model_name != ci.config.clip_model_name:
        ci.config.clip_model_name = clip_model_name
        ci.load_clip_model()

    rgb = image.convert('RGB')

    # Dispatch table replaces the if/elif chain; unknown modes fall through
    # to None exactly as the original implicit return did.
    dispatch = {
        'best': ci.interrogate,
        'classic': ci.interrogate_classic,
        'fast': ci.interrogate_fast,
        'negative': ci.interrogate_negative,
    }
    handler = dispatch.get(mode)
    return handler(rgb) if handler is not None else None
def prompt_tab():
    """Build the "Prompt" tab: an image plus model/mode pickers in, a prompt out."""
    with gr.Column():
        with gr.Row():
            input_image = gr.Image(type='pil', label="Image")
            with gr.Column():
                mode_choice = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')
                clip_choice = gr.Dropdown(list_clip_models(), value=ci.config.clip_model_name, label='CLIP Model')
                caption_choice = gr.Dropdown(list_caption_models(), value=ci.config.caption_model_name, label='Caption Model')
        prompt_box = gr.Textbox(label="Prompt")
        generate = gr.Button("Generate prompt")
        # Wire the button to the interrogation callback defined above.
        generate.click(image_to_prompt, inputs=[input_image, mode_choice, clip_choice, caption_choice], outputs=prompt_box)
def analyze_tab():
    """Build the "Analyze" tab: top-5 rankings per category for an image."""
    with gr.Column():
        with gr.Row():
            input_image = gr.Image(type='pil', label="Image")
            clip_choice = gr.Dropdown(list_clip_models(), value='ViT-L-14/openai', label='CLIP Model')
        with gr.Row():
            # One label widget per category, in the order image_analysis returns.
            rank_outputs = [
                gr.Label(label="Medium", num_top_classes=5),
                gr.Label(label="Artist", num_top_classes=5),
                gr.Label(label="Movement", num_top_classes=5),
                gr.Label(label="Trending", num_top_classes=5),
                gr.Label(label="Flavor", num_top_classes=5),
            ]
        run = gr.Button("Analyze")
        run.click(image_analysis, inputs=[input_image, clip_choice], outputs=rank_outputs)
# Assemble the two-tab UI and start the Gradio server.
_TAB_BUILDERS = (("Prompt", prompt_tab), ("Analyze", analyze_tab))

with gr.Blocks() as demo:
    gr.Markdown("# <center>🕵️♂️ CLIP Interrogator 🕵️♂️</center>")
    for tab_title, build in _TAB_BUILDERS:
        with gr.Tab(tab_title):
            build()

demo.launch(show_api=False, debug=True, share=args.share)