# app.py
"""Streamlit front-end: upload a single image and run ``Model.predict`` on it."""
import contextlib
import os
import tempfile

import streamlit as st
import torch
from PIL import Image

from modules.model import Model
from modules.tokenizer import Tokenizer

# set_page_config must be the first Streamlit call executed by the script.
st.set_page_config(page_title="IMG-Captioning", layout="centered")


@st.cache_resource(ttl=3600)  # keeps the model object in memory across reruns; Streamlit 1.18+
def load_model(checkpoint_path: str, device: torch.device):
    """Load a checkpoint and return the model in eval mode.

    The result is cached by Streamlit per (checkpoint_path, device) for up to
    one hour, so reruns of the script do not reload the weights.

    Args:
        checkpoint_path: Filesystem path of the checkpoint to load.
        device: Device forwarded to ``Model.load_from_checkpoint``.

    Returns:
        The loaded model after ``model.eval()``.
    """
    # NOTE(review): placeholder vocabulary ["b", "a"]; presumably the real
    # vocabulary is restored from the checkpoint by load_from_checkpoint — confirm.
    tokenizer = Tokenizer(["b", "a"])
    model = Model.load_from_checkpoint(
        path=checkpoint_path,
        tokenizer=tokenizer,
        freeze_backbone=True,
        device=device,
    )
    model.eval()
    return model


def run_predict(model, img_path: str, maxlen: int = 256):
    """Return ``model.predict(img_path, maxlen=maxlen)`` evaluated without autograd."""
    # Wrapping in no_grad is harmless if predict() already does it internally,
    # and avoids building a graph if it does not.
    with torch.no_grad():
        return model.predict(img_path, maxlen=maxlen)


def _save_temp_image(image: Image.Image) -> str:
    """Write *image* to a new temporary PNG file and return its path (caller deletes it)."""
    # PNG is lossless, so the model sees exactly the uploaded pixels; re-encoding
    # to JPEG would add compression artifacts. Writing through the open file
    # object (instead of reopening tmp.name) also works on Windows, where a
    # NamedTemporaryFile cannot be opened a second time while it is still open.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        with tmp:
            image.save(tmp, format="PNG")
    except Exception:
        # Do not leave a half-written file behind if encoding/writing fails.
        with contextlib.suppress(OSError):
            os.remove(tmp.name)
        raise
    return tmp.name


def _show_result(result) -> None:
    """Render a prediction result whose type is not known in advance."""
    st.subheader("Model çıktısı")
    if isinstance(result, (str, int, float, list, tuple)):
        st.write(result)
    elif isinstance(result, dict):
        for k, v in result.items():
            st.write(f"**{k}**: {v}")
    else:
        # Fallback: raw repr for anything else.
        st.write(repr(result))


def main():
    """Build the page: load the model, accept an upload, and run prediction on demand."""
    st.title("Görüntü -> Model Predict (Streamlit)")
    st.markdown("Upload a single image and press **Predict**. Model yüklü GPU yoksa CPU'da çalışır.")

    checkpoint_path = st.text_input("Checkpoint path", value="Trained_pt/best.pt")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.info(f"Model device: {device}")

    try:
        model = load_model(checkpoint_path, device)
    except Exception as e:
        # UI boundary: a wrong or missing checkpoint path should produce a
        # readable message rather than a crashed page with a traceback.
        st.error(f"Model yüklenirken hata: {e}")
        st.stop()

    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "bmp"])
    maxlen = st.slider("maxlen", min_value=16, max_value=1024, value=256, step=16)

    if uploaded_file is None:
        return

    # Preview image. PIL raises UnidentifiedImageError (an OSError subclass)
    # for unrecognised data and OSError for truncated/corrupt files.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as e:
        st.error(f"Görüntü okunamadı: {e}")
        return
    st.image(image, caption="Uploaded image", use_column_width=True)

    if not st.button("Predict"):
        return

    # The model is assumed to take an image path, so the upload is written to
    # a temporary file that is always removed afterwards.
    tmp_path = None
    try:
        tmp_path = _save_temp_image(image)
        with st.spinner("Model çalıştırılıyor..."):
            result = run_predict(model, tmp_path, maxlen=maxlen)
    except Exception as e:
        st.error(f"Predict sırasında hata: {e}")
        return
    finally:
        if tmp_path is not None:
            # Best-effort cleanup; failing to delete must not mask the result.
            with contextlib.suppress(OSError):
                os.remove(tmp_path)

    st.success("Predict tamamlandı.")
    _show_result(result)


if __name__ == "__main__":
    main()