"""Streamlit demo app for the gustproof/dnbr-tagger-preview1 image tagger.

The user uploads an image, the image is resized to a patch-aligned
resolution, the model scores every tag, and the tags whose score reaches the
sidebar threshold are listed in a table.
"""

import streamlit as st
import numpy as np
from PIL import Image
import io  # Used to handle image bytes
import torch
import timm
import json
from torchvision.transforms.v2 import (
    ToImage,
    Compose,
    ToDtype,
    Normalize,
)
import pandas as pd
import requests

st.set_page_config(layout="wide")

# Indexing with the bool picks "cuda" when a GPU is available, else "cpu".
device = ["cpu", "cuda"][torch.cuda.is_available()]


def NativeResize(patch_size, seq_len_range):
    """Build a resizer that snaps an image to a patch-aligned size.

    Candidate sizes are ``(i * patch_size, j * patch_size)`` for every
    ``i, j`` in ``4..99`` whose aspect ratio ``i / j`` lies in ``[0.33, 3]``
    and whose patch count ``i * j`` satisfies ``lo <= i * j < hi``.

    Args:
        patch_size: Side length, in pixels, that both output dimensions are
            a multiple of.
        seq_len_range: ``(lo, hi)`` half-open bounds on ``i * j``.

    Returns:
        A function taking a ``PIL.Image.Image`` and returning it resized
        (bicubic) to the candidate whose aspect ratio is closest to the
        image's own width / height ratio.

    Raises:
        ValueError: If no candidate grid satisfies the constraints.
    """
    p, lo, hi = patch_size, *seq_len_range
    refs = sorted(
        [
            (i / j, i * p, j * p)
            for i in range(4, 100)
            for j in range(4, 100)
            if 0.33 <= i / j <= 3 and lo <= i * j < hi
        ]
    )
    if not refs:
        # Fail at construction instead of with an opaque "min() arg is an
        # empty sequence" on the first resize call.
        raise ValueError(
            f"No patch grid satisfies {lo} <= i * j < {hi} "
            "with 4 <= i, j < 100 and 0.33 <= i / j <= 3"
        )

    def get_ratio(r):
        # Relative distance between two ratios: (larger / smaller) - 1.
        return min(refs, key=lambda rr: max(r, rr[0]) / min(r, rr[0]) - 1)

    def f(im: Image.Image):
        w, h = im.size
        _, sw, sh = get_ratio(w / h)
        return im.resize((sw, sh), resample=Image.Resampling.BICUBIC)

    return f


def load_json_from_url(url, timeout=30):
    """Fetch ``url`` and parse the response body as JSON.

    Args:
        url: Address to download.
        timeout: Seconds to wait for the server before giving up, so a
            stalled connection cannot hang the app forever.

    Returns:
        The parsed JSON value, or ``None`` if the request failed (including
        a timeout or a bad HTTP status) or the body was not valid JSON. The
        error is printed in both cases.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raise an exception for bad status codes
        return json.loads(response.text)
    except requests.exceptions.RequestException as e:
        print(f"Error fetching data from URL: {e}")
        return None
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON data: {e}")
        return None


@st.cache_data
def load_tags():
    """Download the tag list and append a trailing placeholder entry.

    Returns:
        A list of the entries from ``freqs.json`` followed by
        ``("PLACEHOLDER", 0)``. The UI reads element 0 of an entry as the
        label and element 1 as the dataset frequency.

    Raises:
        RuntimeError: If the tag list could not be downloaded or parsed.
    """
    freqs = load_json_from_url("https://huggingface.co/gustproof/dnbr-tagger-preview1/raw/main/freqs.json")
    if freqs is None:
        # Without this check the unpacking below fails with an unhelpful
        # "Value after * must be an iterable, not NoneType".
        raise RuntimeError("Could not load the tag list (freqs.json)")
    # NOTE(review): the extra entry presumably covers a model output that has
    # no row in freqs.json -- confirm against the model's class count.
    freqs = [*freqs, (("PLACEHOLDER", 0))]
    return freqs


tags = load_tags()


@st.cache_resource
def load_model():
    """Create the tagger once and return a small prediction wrapper.

    Returns:
        An object with ``class_names`` (the result of ``load_tags()``) and a
        ``predict(img)`` method.
    """
    # Grad mode is thread-local and this cached function body runs only once,
    # so this call alone does not cover script runs on other threads;
    # ``predict`` therefore disables gradients itself as well.
    torch.set_grad_enabled(False)
    net = (
        timm.create_model(
            "hf_hub:gustproof/dnbr-tagger-preview1",
            pretrained=True,
            dynamic_img_size=True,
        )
        .eval()
        .to(device)
    )
    print("loaded model")
    tf = Compose(
        [
            ToImage(),
            ToDtype(torch.float, scale=True),
            Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
        ]
    )

    class Model:
        """Bundles preprocessing and the network behind ``predict``."""

        def __init__(self):
            self.class_names = load_tags()

        def predict(self, img):
            """Return the sigmoid scores for ``img`` as a CPU tensor.

            The image is converted and normalised, given a leading batch
            dimension for the forward pass, and that dimension is removed
            again from the output.
            """
            with torch.no_grad():
                x = tf(img).unsqueeze(0).to(device)
                probs = net(x).squeeze(0).sigmoid().cpu()
            return probs

    return Model()


model = load_model()


def get_category(ti):
    """Map a tag index to its display category.

    NOTE(review): the boundaries 4 and 8856 presumably mirror the ordering of
    freqs.json (ratings first, then general tags, then characters) -- verify
    whenever the tag list changes.
    """
    if ti < 4:
        return "Rating"
    if ti < 8856:
        return "General"
    return "Character"


# --- Streamlit App Layout ---
st.title("Tagger demo")
st.write("Model: [gustproof/dnbr-tagger-preview1](https://huggingface.co/gustproof/dnbr-tagger-preview1)")
st.write(
    "Upload an image to see predicted labels."
)
st.write("---")

# --- Sidebar for Controls ---
st.sidebar.header("Settings")

# Confidence Threshold Slider
confidence_threshold = st.sidebar.slider(
    "Threshold (recommended: ~0.4-~0.6)",
    min_value=0.0,
    max_value=1.0,
    value=0.5,  # Default threshold
    step=0.01,
)

# --- Main Area ---
col1, col2 = st.columns(2)

with col1:
    st.header("Upload Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])

    if uploaded_file is not None:
        # Read the image bytes
        image_bytes = uploaded_file.getvalue()

        # Display the uploaded image
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            image = NativeResize(14, (270, 301))(image)
            w, h = image.size
            st.image(image, caption=f"Resized: {w}x{h}")
        except Exception as e:
            st.error(f"Error opening image: {e}")
            st.warning("Please upload a valid image file")
            uploaded_file = None  # Reset uploaded_file so processing stops

with col2:
    st.header("Predictions")
    if uploaded_file is not None:
        with st.spinner("Computing..."):
            try:
                scores = model.predict(image)
                # Keep (tag index, score) pairs at or above the threshold,
                # highest score first.
                filtered_results = [
                    (i, p) for i, p in enumerate(scores) if p >= confidence_threshold
                ]
                filtered_results.sort(key=lambda x: x[1], reverse=True)
                if filtered_results:
                    # Only the 200 best-scoring labels are shown.
                    df = pd.DataFrame(
                        [
                            (i, tags[ti][0], f"{p:.4f}", get_category(ti), tags[ti][1])
                            for i, (ti, p) in enumerate(filtered_results[:200], 1)
                        ],
                        columns=[
                            "Rank",
                            "Label",
                            "Score",
                            "Category",
                            "Dataset frequency",
                        ],
                    )
                    st.dataframe(
                        df,
                        hide_index=True,
                        column_config={
                            "Dataset frequency": st.column_config.NumberColumn(
                                format="localized"
                            )
                        },
                    )
                else:
                    st.info("No labels meet the current threshold.")
            except Exception as e:
                st.error("An error occurred during prediction or processing:")
                st.exception(e)  # Shows the full traceback
    else:
        st.info("Upload an image using the panel on the left to see predictions.")