# image_captioner/src/streamlit_app.py
# (NOTE: removed Hugging Face Spaces page artifacts that had been pasted into
#  this file: author avatar line, commit message "Update src/streamlit_app.py",
#  commit hash "fcc2785 verified" — they were not valid Python.)
# app.py
import os
import tempfile

import streamlit as st
import torch
from PIL import Image

from modules.model import Model
from modules.tokenizer import Tokenizer
# Must be the first Streamlit call in the script.
st.set_page_config(page_title="IMG-Captioning", layout="centered")
@st.cache_resource(ttl=3600)  # keep the model object in memory; needs Streamlit 1.18+
def load_model(checkpoint_path: str, device: torch.device):
    """Load the captioning model from *checkpoint_path* onto *device*.

    Cached by Streamlit (keyed on the arguments) so reruns reuse one instance.
    """
    # NOTE(review): this tokenizer vocabulary looks like a placeholder —
    # presumably the checkpoint restores the real vocabulary; confirm against
    # Model.load_from_checkpoint.
    placeholder_tokenizer = Tokenizer(["b", "a"])
    captioner = Model.load_from_checkpoint(
        path=checkpoint_path,
        tokenizer=placeholder_tokenizer,
        freeze_backbone=True,
        device=device,
    )
    # Inference mode: disable dropout/batch-norm training behavior.
    captioner.eval()
    return captioner
def run_predict(model, img_path: str, maxlen: int = 256):
    """Run *model*'s predict on the image at *img_path* with autograd disabled.

    Wrapping in ``torch.no_grad()`` here guards against ``predict``
    implementations that forget to disable gradient tracking themselves.
    """
    with torch.no_grad():
        prediction = model.predict(img_path, maxlen=maxlen)
    return prediction
def main():
    """Streamlit UI: upload one image, run the cached captioning model, show output."""
    st.title("Görüntü -> Model Predict (Streamlit)")
    st.markdown("Upload a single image and press **Predict**. Model yüklü GPU yoksa CPU'da çalışır.")

    checkpoint_path = st.text_input("Checkpoint path", value="Trained_pt/best.pt")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.info(f"Model device: {device}")
    model = load_model(checkpoint_path, device)

    uploaded_file = st.file_uploader("Upload an image", type=["png","jpg","jpeg","bmp"])
    maxlen = st.slider("maxlen", min_value=16, max_value=1024, value=256, step=16)

    if uploaded_file is not None:
        # Preview the uploaded image.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded image", use_column_width=True)

        if st.button("Predict"):
            # The model expects a filesystem path, so persist the image to a
            # temp file. delete=False because the file must outlive the `with`
            # block; it is removed explicitly in the `finally` below.
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                image.save(tmp.name)
                tmp_path = tmp.name
            try:
                with st.spinner("Model çalıştırılıyor..."):
                    result = run_predict(model, tmp_path, maxlen=maxlen)
                st.success("Predict tamamlandı.")
            except Exception as e:
                st.error(f"Predict sırasında hata: {e}")
                return
            finally:
                # BUGFIX: the original leaked one temp file per prediction.
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass  # best-effort cleanup; nothing useful to surface

            # Show the result; its type is unknown, so display generically.
            st.subheader("Model çıktısı")
            if isinstance(result, (str, int, float)):
                st.write(result)
            elif isinstance(result, dict):
                for k, v in result.items():
                    st.write(f"**{k}**: {v}")
            elif isinstance(result, (list, tuple)):
                st.write(result)
            else:
                # Fallback: raw repr for unexpected types.
                st.write(repr(result))
# Script entry point.
if __name__ == "__main__":
    main()