Spaces:
Running
Running
| import aiohttp | |
| import io | |
| import random | |
| import panel as pn | |
| from PIL import Image | |
| from transformers import CLIPProcessor, CLIPModel | |
| from typing import List, Tuple | |
# Initialise Panel with the Bootstrap design and make components stretch to
# the available width unless they override sizing_mode themselves.
pn.extension(sizing_mode="stretch_width", design="bootstrap")
async def random_url(_):
    """
    Return the URL of a random cat or dog picture.

    The single positional argument is ignored; it only exists so the function
    can be bound to the "Randomize URL" button via ``pn.bind``, which passes
    the button's value.

    Raises:
        aiohttp.ClientResponseError: if the picture API answers with an
            HTTP error status.
    """
    api_url = random.choice([
        "https://api.thecatapi.com/v1/images/search",
        "https://api.thedogapi.com/v1/images/search",
    ])
    async with aiohttp.ClientSession() as session:
        async with session.get(api_url) as resp:
            # Surface HTTP failures as a clear error instead of a
            # KeyError/IndexError from indexing an error payload below.
            resp.raise_for_status()
            # The code expects a JSON list whose first entry has a "url" key.
            return (await resp.json())[0]["url"]
@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[CLIPProcessor, CLIPModel]:
    """
    Load a CLIP processor and model from the Hugging Face hub.

    Memoized with ``pn.cache``: this is called for every classification
    request, and without caching the pretrained weights would be reloaded
    each time. ``pn.cache`` is used rather than ``functools.lru_cache``
    because Panel re-executes the script per session, which would recreate
    (and empty) a plain lru_cache.

    Args:
        processor_name: Hub identifier of the processor to load.
        model_name: Hub identifier of the model to load.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = CLIPProcessor.from_pretrained(processor_name)
    model = CLIPModel.from_pretrained(model_name)
    return processor, model
async def open_image_url(image_url: str) -> Image.Image:
    """
    Download ``image_url`` and decode the response body into a PIL image.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an HTTP
            error status.
        aiohttp.ClientError: on other connection/URL problems.
        PIL.UnidentifiedImageError: if the body cannot be decoded as an image.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as resp:
            # Without this check a 404 page would be handed to PIL and fail
            # with a confusing "cannot identify image file" error.
            resp.raise_for_status()
            data = await resp.read()
    # The bytes are fully in memory, so decoding does not need the
    # connection to stay open.
    return Image.open(io.BytesIO(data))
def get_similarity_scores(
    class_items: List[str], image: Image.Image
) -> List[float]:
    """
    Score how well each candidate label describes ``image`` using CLIP.

    Args:
        class_items: Candidate labels; one score is returned per label,
            in the same order.
        image: The image to classify.

    Returns:
        Softmax probabilities over ``class_items`` (they sum to 1).
    """
    import torch  # already required by return_tensors="pt" below

    processor, model = load_processor_model(
        "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
    )
    inputs = processor(
        text=class_items,
        images=[image],
        return_tensors="pt",  # pytorch tensors
        # Labels tokenize to different lengths (e.g. "cat" vs "golden
        # retriever"); without padding they cannot be stacked into a tensor.
        padding=True,
    )
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    # One row of logits per image; exactly one image was passed in, so
    # softmax over the label axis and take that single row.
    class_likelihoods = outputs.logits_per_image.softmax(dim=1)
    return class_likelihoods[0].tolist()
def _likelihood_row(class_item: str, class_likelihood: float):
    """Build a label plus progress bar showing one class's likelihood."""
    row_label = pn.widgets.StaticText(
        name=class_item, value=f"{class_likelihood:.2%}", align='center'
    )
    row_bar = pn.indicators.Progress(
        value=int(class_likelihood * 100),
        sizing_mode="stretch_width",
        bar_color="secondary",
        margin=(0, 10),
        design=pn.theme.Material,
    )
    return pn.Column(row_label, row_bar)


async def process_inputs(class_names: str, image_url: str):
    """
    High level function that takes in the user inputs and returns the
    classification results as panel objects.

    Used as an async generator bound to the input widgets: it first yields
    status messages, then the final results column.

    Args:
        class_names: Comma separated candidate labels, e.g. "cat, dog".
        image_url: URL of the image to classify.

    Yields:
        Markdown status strings, then a ``pn.Column`` holding the image and
        one labelled progress bar per class.
    """
    import asyncio  # only needed for the timeout exception type below

    if not image_url:
        yield '## Provide an image URL'
        return

    # Strip whitespace and drop empty entries so input such as "cat, dog,"
    # does not produce a score for an empty label.
    class_items = [item.strip() for item in class_names.split(",") if item.strip()]
    if not class_items:
        yield '## Provide at least one class name'
        return

    yield '## Fetching image and running model ⚙'
    try:
        pil_img = await open_image_url(image_url)
    except (aiohttp.ClientError, asyncio.TimeoutError, OSError, ValueError):
        # Unreachable hosts, HTTP errors, invalid URLs and undecodable
        # images are user-input problems: report them in the UI instead of
        # raising out of the app.
        yield '## Could not load an image from that URL, please try another one'
        return
    img = pn.pane.Image(pil_img, height=400, align='center')

    class_likelihoods = get_similarity_scores(class_items, pil_img)

    # build the results column
    results = pn.Column("## 🎉 Here are the results!", img)
    for class_item, class_likelihood in zip(class_items, class_likelihoods):
        results.append(_likelihood_row(class_item, class_likelihood))
    yield results
# create widgets
randomize_url = pn.widgets.Button(name="Randomize URL", align="end")
# The URL field is bound to the button: clicking it fills in a random
# cat/dog picture URL via random_url.
image_url = pn.widgets.TextInput(
    name="Image URL to classify",
    value=pn.bind(random_url, randomize_url),
)
class_names = pn.widgets.TextInput(
    name="Comma separated class names",
    placeholder="Enter possible class names, e.g. cat, dog",
    value="cat, dog, parrot",
)
input_widgets = pn.Column(
    "## 👉 Click randomize or paste a URL to start classifying!",
    pn.Row(image_url, randomize_url),
    class_names,
)

# add interactivity: re-run the classification whenever either input changes
interactive_result = pn.bind(
    process_inputs, image_url=image_url, class_names=class_names
)

# create dashboard
main = pn.WidgetBox(
    input_widgets,
    interactive_result,
)

pn.template.BootstrapTemplate(
    title="Panel Image Classification Demo",
    main=main,
    main_max_width="min(50%, 698px)",
    header_background="#F08080",
).servable(title="Panel Image Classification Demo")