webkal / app.py
Clemenz88's picture
Upload 2 files
1d7bad3 verified
import streamlit as st
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image
import torch
import numpy as np
@st.cache_resource
def load_model():
    """Load the image preprocessor and the food classifier.

    The function is wrapped in ``st.cache_resource``, so Streamlit keeps the
    returned objects and reuses them instead of loading them again on every
    script rerun.

    Returns:
        A ``(extractor, model)`` tuple: the feature extractor used to turn
        images into model inputs, and the image-classification model.
    """
    preprocessing_checkpoint = "google/vit-base-patch16-224"
    classifier_checkpoint = "timm/food101-vit-base-patch16-224"
    # NOTE(review): the preprocessor and the classifier are loaded from two
    # different checkpoints — confirm the preprocessing (resize/normalisation)
    # matches what the classifier was trained with.
    return (
        AutoFeatureExtractor.from_pretrained(preprocessing_checkpoint),
        AutoModelForImageClassification.from_pretrained(classifier_checkpoint),
    )
def predict(image, extractor, model, top_k=3):
    """Classify *image* and return the most probable labels.

    Args:
        image: The image to classify; passed to ``extractor`` as ``images=``.
        extractor: Callable that converts the image into a mapping of model
            inputs (called with ``return_tensors="pt"``).
        model: Callable classifier whose output exposes ``logits`` and whose
            ``config.id2label`` maps class indices to label names.
        top_k: Maximum number of predictions to return. Clamped to the number
            of classes the model outputs, so models with fewer than ``top_k``
            classes no longer make ``torch.topk`` raise.

    Returns:
        A list of ``(label, probability)`` tuples sorted by descending
        probability, with at most ``top_k`` entries.
    """
    inputs = extractor(images=image, return_tensors="pt")
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    # Only the first (and, for a single image, only) batch element is used.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = max(0, min(top_k, probs.numel()))
    top_probs, top_idxs = torch.topk(probs, k)

    id2label = model.config.id2label
    return [
        (id2label[idx], prob)
        for idx, prob in zip(top_idxs.tolist(), top_probs.tolist())
    ]
# Streamlit page: upload a food image, show the classifier's top guesses,
# offer a manual choice when confidence is low, and collect feedback.

# Top-1 probability below which the user is asked to pick the food manually.
CONFIDENCE_THRESHOLD = 0.70

st.title("Kalorieestimat baseret på madbillede")
uploaded_file = st.file_uploader("Upload billede", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot decode, and OSError for truncated data. Show a message
        # instead of crashing the app with a traceback.
        st.error("Billedet kunne ikke læses. Upload venligst en gyldig JPG- eller PNG-fil.")
        st.stop()

    # NOTE(review): newer Streamlit releases deprecate `use_column_width`;
    # confirm against the Streamlit version this app is pinned to.
    st.image(image, caption="Dit billede", use_column_width=True)

    extractor, model = load_model()
    predictions = predict(image, extractor, model)

    st.subheader("Modelgæt")
    for label, prob in predictions:
        st.write(f"{label}: {prob:.2%}")

    best_label, best_prob = predictions[0]
    if best_prob < CONFIDENCE_THRESHOLD:
        # Low confidence: let the user choose among the guesses or "Andet".
        st.warning("Sikkerheden er lav. Vælg venligst fødevare manuelt.")
        fallback = st.selectbox(
            "Vælg madtype",
            options=[label for label, _ in predictions] + ["Andet"],
        )
        st.write(f"Valgt: {fallback}")
    else:
        st.success(f"Detekteret: {best_label}")

    # The feedback is only echoed back to the user; it is not stored anywhere.
    feedback = st.radio("Er dette korrekt?", ["Ja", "Nej", "Ved ikke"])
    if st.button("Send feedback"):
        st.info(f"Tak for din feedback: {feedback}")