|
|
|
|
|
import streamlit as st |
|
|
from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
|
from PIL import Image |
|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the image preprocessor and the classification model.

    The function is wrapped in ``st.cache_resource``, so Streamlit reuses the
    returned objects instead of reloading them on every script rerun.

    Returns:
        A ``(extractor, model)`` tuple: the feature extractor loaded from
        ``google/vit-base-patch16-224`` and the image-classification model
        loaded from ``timm/food101-vit-base-patch16-224``.
    """
    # NOTE(review): the preprocessor and the classifier come from different
    # checkpoints - confirm that the preprocessing matches what the
    # classifier was trained with.
    return (
        AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224"),
        AutoModelForImageClassification.from_pretrained("timm/food101-vit-base-patch16-224"),
    )
|
|
|
|
|
def predict(image, extractor, model, top_k: int = 3):
    """Return the most probable labels for ``image`` with their probabilities.

    Args:
        image: The image to classify; passed to ``extractor`` as ``images=``.
        extractor: Callable preprocessor. It is called with
            ``return_tensors="pt"`` and its result is unpacked as keyword
            arguments into ``model``.
        model: Callable classifier whose output exposes ``logits`` and whose
            ``config.id2label`` maps class indices to label names.
        top_k: Maximum number of predictions to return. Clamped to the number
            of classes, so a model with fewer classes returns fewer entries.

    Returns:
        A list of ``(label, probability)`` tuples ordered from most to least
        probable.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    inputs = extractor(images=image, return_tensors="pt")
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(**inputs)

    # Only the first row of the batch is scored.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    # torch.topk raises when k exceeds the number of classes, so clamp it.
    k = min(top_k, probs.shape[-1])
    top_probs, top_idxs = torch.topk(probs, k)

    id2label = model.config.id2label
    return [
        (id2label[idx], prob)
        for idx, prob in zip(top_idxs.tolist(), top_probs.tolist())
    ]
|
|
|
|
|
# Streamlit page: upload a food photo, show the model's top guesses, and
# collect simple feedback on the result.
st.title("Kalorieestimat baseret på madbillede")

# Top-1 probability below which the guess is treated as unreliable and the
# user is asked to pick the food type manually.
CONFIDENCE_THRESHOLD = 0.70

uploaded_file = st.file_uploader("Upload billede", type=["jpg", "jpeg", "png"])

# NOTE(review): the original indentation was lost. The nesting below is
# reconstructed from data flow (everything using `predictions` must sit in
# the upload branch); the feedback widgets are assumed to belong there too.
if uploaded_file is not None:
    try:
        # Image.open is lazy; convert() forces the decode, so both calls can
        # fail on a corrupt or mislabelled upload.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Pillow's UnidentifiedImageError is an OSError subclass.
        st.error("Billedet kunne ikke læses. Upload venligst en anden fil.")
        st.stop()

    st.image(image, caption="Dit billede", use_column_width=True)

    extractor, model = load_model()
    predictions = predict(image, extractor, model)

    st.subheader("Modelgæt")
    for label, prob in predictions:
        st.write(f"{label}: {prob:.2%}")

    top_label, top_prob = predictions[0]
    if top_prob < CONFIDENCE_THRESHOLD:
        st.warning("Sikkerheden er lav. Vælg venligst fødevare manuelt.")
        # Offer the model's own guesses plus a catch-all "other" entry.
        fallback = st.selectbox(
            "Vælg madtype",
            options=[label for label, _ in predictions] + ["Andet"],
        )
        st.write(f"Valgt: {fallback}")
    else:
        st.success(f"Detekteret: {top_label}")

    feedback = st.radio("Er dette korrekt?", ["Ja", "Nej", "Ved ikke"])
    if st.button("Send feedback"):
        st.info(f"Tak for din feedback: {feedback}")
|
|
|