dnbr-tagger-preview1-demo / src /streamlit_app.py
gustproof's picture
Update src/streamlit_app.py
6851e01 verified
import streamlit as st
import numpy as np
from PIL import Image
import io # Used to handle image bytes
import torch
import timm
import json
from torchvision.transforms.v2 import (
ToImage,
Compose,
ToDtype,
Normalize,
)
import pandas as pd
import requests
# Select Streamlit's "wide" page layout for the app.
st.set_page_config(layout="wide")
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def NativeResize(patch_size, seq_len_range):
    """Build a resizer that snaps images to a patch-aligned size.

    Every grid of ``cols x rows`` patches with ``4 <= cols, rows < 100``, an
    aspect ratio ``cols / rows`` within ``[0.33, 3]`` and a patch count
    ``cols * rows`` inside ``seq_len_range`` is a candidate output size. The
    returned callable picks the candidate whose aspect ratio is closest (as a
    larger-over-smaller quotient) to the input image's ratio and resizes the
    image to it with bicubic resampling.

    Args:
        patch_size: Edge length of one patch in pixels; each output side is
            a multiple of it.
        seq_len_range: ``(lo, hi)`` bounds on the patch count, applied as
            ``lo <= cols * rows < hi``.

    Returns:
        A function taking a ``PIL.Image.Image`` and returning the resized
        image.

    Raises:
        ValueError: If no patch grid satisfies ``seq_len_range``.
    """
    lo, hi = seq_len_range
    # Tuples of (aspect ratio, width px, height px). Sorting makes the choice
    # between equally good candidates deterministic (first one wins in min()).
    candidates = sorted(
        (cols / rows, cols * patch_size, rows * patch_size)
        for cols in range(4, 100)
        for rows in range(4, 100)
        if 0.33 <= cols / rows <= 3 and lo <= cols * rows < hi
    )
    if not candidates:
        # Fail here with a clear message instead of letting min() raise
        # "arg is an empty sequence" later, on the first image.
        raise ValueError(
            f"no patch grid satisfies seq_len_range={seq_len_range!r}"
        )

    def closest(ratio):
        # Relative mismatch between two ratios: larger / smaller - 1.
        return min(
            candidates,
            key=lambda cand: max(ratio, cand[0]) / min(ratio, cand[0]) - 1,
        )

    def resize(im: Image.Image):
        width, height = im.size
        _, new_width, new_height = closest(width / height)
        return im.resize(
            (new_width, new_height), resample=Image.Resampling.BICUBIC
        )

    return resize
def load_json_from_url(url, timeout=30):
    """Fetch ``url`` with an HTTP GET and parse the response body as JSON.

    Args:
        url: Address to request.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests.get`` may block indefinitely.

    Returns:
        The decoded JSON value, or ``None`` if the request failed (including
        a timeout or an error status code) or the body was not valid JSON.
        The error is printed in either case.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raise an exception for bad status codes
        return json.loads(response.text)
    except requests.exceptions.RequestException as e:
        print(f"Error fetching data from URL: {e}")
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON data: {e}")
    return None
@st.cache_data
def load_tags():
    """Download the tag/frequency list and append a placeholder entry.

    Returns:
        The entries of ``freqs.json`` followed by ``("PLACEHOLDER", 0)``.

    Raises:
        RuntimeError: If the list could not be downloaded or parsed.
    """
    freqs = load_json_from_url(
        "https://huggingface.co/gustproof/dnbr-tagger-preview1/raw/main/freqs.json"
    )
    if freqs is None:
        # load_json_from_url signals failure by returning None; stop here
        # with a clear message instead of failing on the unpacking below.
        raise RuntimeError(
            "Could not load tag frequencies (freqs.json) from the model repository"
        )
    # NOTE(review): the extra entry presumably pads the table to match the
    # model's output size - confirm.
    return [*freqs, ("PLACEHOLDER", 0)]


# Tag table; the app reads entry[0] as the label and entry[1] as its
# dataset frequency.
tags = load_tags()
@st.cache_resource
def load_model():
    """Load the tagger network and wrap it with its preprocessing pipeline.

    Returns:
        An object exposing ``class_names`` (the result of ``load_tags()``)
        and ``predict(img)``, which returns the per-class sigmoid scores as
        a CPU tensor.
    """
    # Kept for backward compatibility. Grad mode is thread-local in PyTorch
    # and this body only runs on a cache miss, so the call cannot be relied
    # on when ``predict`` is later invoked from another thread; ``predict``
    # therefore disables gradients itself.
    torch.set_grad_enabled(False)
    net = (
        timm.create_model(
            "hf_hub:gustproof/dnbr-tagger-preview1",
            pretrained=True,
            dynamic_img_size=True,
        )
        .eval()
        .to(device)
    )
    print("loaded model")
    # Convert to a float tensor scaled to [0, 1], then normalize per channel.
    preprocess = Compose(
        [
            ToImage(),
            ToDtype(torch.float, scale=True),
            Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
        ]
    )

    class Model:
        """Thin wrapper pairing the network with preprocessing and tag names."""

        def __init__(self):
            self.class_names = load_tags()

        def predict(self, img):
            """Return sigmoid scores for a single image as a CPU tensor.

            Args:
                img: Input accepted by the preprocessing transforms (the app
                    passes an RGB PIL image).
            """
            # Explicitly run without autograd regardless of the calling
            # thread's grad mode.
            with torch.no_grad():
                batch = preprocess(img).unsqueeze(0).to(device)
                return net(batch).squeeze(0).sigmoid().cpu()

    return Model()


# Shared model wrapper, created once through the st.cache_resource cache.
model = load_model()
# --- Page header ---
st.title("Tagger demo")
# Model link, a short usage hint, and a "---" separator, written in order.
for intro_line in (
    "Model: [gustproof/dnbr-tagger-preview1](https://huggingface.co/gustproof/dnbr-tagger-preview1)",
    "Upload an image to see predicted labels.",
    "---",
):
    st.write(intro_line)
# --- Sidebar controls ---
sidebar = st.sidebar
sidebar.header("Settings")
# Minimum score a label needs in order to be listed (defaults to 0.5).
confidence_threshold = sidebar.slider(
    "Threshold (recommended: ~0.4-~0.6)",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
    step=0.01,
)
# --- Main area: upload panel on the left, predictions on the right ---
col1, col2 = st.columns(2)
with col1:
    st.header("Upload Image")
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png", "webp"],
    )
    if uploaded_file is not None:
        # Raw bytes of the upload; decoding happens inside the try below.
        image_bytes = uploaded_file.getvalue()
        try:
            # Decode as RGB, then resize to a patch-aligned size for the model.
            decoded = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            resize = NativeResize(14, (270, 301))
            image = resize(decoded)
            w, h = image.size
            st.image(image, caption=f"Resized: {w}x{h}")
        except Exception as e:
            st.error(f"Error opening image: {e}")
            st.warning("Please upload a valid image file")
            # Clear the upload so the predictions panel skips processing.
            uploaded_file = None
# Tag indices below RATING_TAG_END are labelled "Rating", indices below
# GENERAL_TAG_END are labelled "General", and the rest "Character".
RATING_TAG_END = 4
GENERAL_TAG_END = 8856
# Maximum number of rows shown in the results table.
MAX_DISPLAYED_LABELS = 200


def get_category(ti):
    """Return the category name for tag index ``ti``."""
    if ti < RATING_TAG_END:
        return "Rating"
    if ti < GENERAL_TAG_END:
        return "General"
    return "Character"


with col2:
    st.header("Predictions")
    if uploaded_file is not None:
        with st.spinner("Computing..."):
            try:
                scores = model.predict(image)
                # Threshold the whole score vector in one tensor operation
                # instead of looping over thousands of 0-d tensors; only the
                # surviving (index, score) pairs become Python values.
                keep = scores >= confidence_threshold
                kept_indices = keep.nonzero().flatten().tolist()
                kept_scores = scores[keep].tolist()
                # Highest score first; the stable sort keeps ties in index order.
                filtered_results = sorted(
                    zip(kept_indices, kept_scores),
                    key=lambda x: x[1],
                    reverse=True,
                )
                if filtered_results:
                    df = pd.DataFrame(
                        [
                            (i, tags[ti][0], f"{p:.4f}", get_category(ti), tags[ti][1])
                            for i, (ti, p) in enumerate(
                                filtered_results[:MAX_DISPLAYED_LABELS], 1
                            )
                        ],
                        columns=[
                            "Rank",
                            "Label",
                            "Score",
                            "Category",
                            "Dataset frequency",
                        ],
                    )
                    st.dataframe(
                        df,
                        hide_index=True,
                        column_config={
                            "Dataset frequency": st.column_config.NumberColumn(
                                format="localized"
                            )
                        },
                    )
                else:
                    st.info("No labels meet the current threshold.")
            except Exception as e:
                # UI boundary: surface any failure in the page rather than
                # crashing the script run.
                st.error("An error occurred during prediction or processing:")
                st.exception(e)  # Shows the full traceback
    else:
        st.info("Upload an image using the panel on the left to see predictions.")