| | import io |
| | import random |
| | from typing import List, Tuple |
| |
|
| | import aiohttp |
| | import panel as pn |
| | from PIL import Image |
| | from transformers import CLIPModel, CLIPProcessor |
| |
|
| | pn.extension(design="bootstrap", sizing_mode="stretch_width") |
| |
|
| | ICON_URLS = { |
| | "brand-github": "https://github.com/holoviz/panel", |
| | "brand-twitter": "https://twitter.com/Panel_Org", |
| | "brand-linkedin": "https://www.linkedin.com/company/panel-org", |
| | "message-circle": "https://discourse.holoviz.org/", |
| | "brand-discord": "https://discord.gg/AXRHnJU6sP", |
| | } |
| |
|
| |
|
| | async def random_url(_): |
| | pet = random.choice(["cat", "dog"]) |
| | api_url = f"https://api.the{pet}api.com/v1/images/search" |
| | async with aiohttp.ClientSession() as session: |
| | async with session.get(api_url) as resp: |
| | return (await resp.json())[0]["url"] |
| |
|
| |
|
| | @pn.cache |
| | def load_processor_model( |
| | processor_name: str, model_name: str |
| | ) -> Tuple[CLIPProcessor, CLIPModel]: |
| | processor = CLIPProcessor.from_pretrained(processor_name) |
| | model = CLIPModel.from_pretrained(model_name) |
| | return processor, model |
| |
|
| |
|
| | async def open_image_url(image_url: str) -> Image: |
| | async with aiohttp.ClientSession() as session: |
| | async with session.get(image_url) as resp: |
| | return Image.open(io.BytesIO(await resp.read())) |
| |
|
| |
|
| | def get_similarity_scores(class_items: List[str], image: Image) -> List[float]: |
| | processor, model = load_processor_model( |
| | "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32" |
| | ) |
| | inputs = processor( |
| | text=class_items, |
| | images=[image], |
| | return_tensors="pt", |
| | ) |
| | outputs = model(**inputs) |
| | logits_per_image = outputs.logits_per_image |
| | class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy() |
| | return class_likelihoods[0] |
| |
|
| |
|
| | async def process_inputs(class_names: List[str], image_url: str): |
| | """ |
| | High level function that takes in the user inputs and returns the |
| | classification results as panel objects. |
| | """ |
| | try: |
| | main.disabled = True |
| | if not image_url: |
| | yield "##### β οΈ Provide an image URL" |
| | return |
| | |
| | yield "##### β Fetching image and running model..." |
| | try: |
| | pil_img = await open_image_url(image_url) |
| | img = pn.pane.Image(pil_img, height=400, align="center") |
| | except Exception as e: |
| | yield f"##### π Something went wrong, please try a different URL!" |
| | return |
| | |
| | class_items = class_names.split(",") |
| | class_likelihoods = get_similarity_scores(class_items, pil_img) |
| | |
| | |
| | results = pn.Column("##### π Here are the results!", img) |
| | |
| | for class_item, class_likelihood in zip(class_items, class_likelihoods): |
| | row_label = pn.widgets.StaticText( |
| | name=class_item.strip(), value=f"{class_likelihood:.2%}", align="center" |
| | ) |
| | row_bar = pn.indicators.Progress( |
| | value=int(class_likelihood * 100), |
| | sizing_mode="stretch_width", |
| | bar_color="secondary", |
| | margin=(0, 10), |
| | design=pn.theme.Material, |
| | ) |
| | results.append(pn.Column(row_label, row_bar)) |
| | yield results |
| | finally: |
| | main.disabled = False |
| |
|
| |
|
| | |
| | randomize_url = pn.widgets.Button(name="Randomize URL", align="end") |
| |
|
| | image_url = pn.widgets.TextInput( |
| | name="Image URL to classify", |
| | value=pn.bind(random_url, randomize_url), |
| | ) |
| | class_names = pn.widgets.TextInput( |
| | name="Comma separated class names", |
| | placeholder="Enter possible class names, e.g. cat, dog", |
| | value="cat, dog, parrot", |
| | ) |
| |
|
| | input_widgets = pn.Column( |
| | "##### π Click randomize or paste a URL to start classifying!", |
| | pn.Row(image_url, randomize_url), |
| | class_names, |
| | ) |
| |
|
| | |
| | interactive_result = pn.panel( |
| | pn.bind(process_inputs, image_url=image_url, class_names=class_names), |
| | height=600, |
| | ) |
| |
|
| | |
| | footer_row = pn.Row(pn.Spacer(), align="center") |
| | for icon, url in ICON_URLS.items(): |
| | href_button = pn.widgets.Button(icon=icon, width=35, height=35) |
| | href_button.js_on_click(code=f"window.open('{url}')") |
| | footer_row.append(href_button) |
| | footer_row.append(pn.Spacer()) |
| |
|
| | |
| | main = pn.WidgetBox( |
| | input_widgets, |
| | interactive_result, |
| | footer_row, |
| | ) |
| |
|
| | title = "Panel Demo - Image Classification" |
| | pn.template.BootstrapTemplate( |
| | title=title, |
| | main=main, |
| | main_max_width="min(50%, 698px)", |
| | header_background="#F08080", |
| | ).servable(title=title) |