Spaces:
Sleeping
Sleeping
| # app.py | |
import os
import tempfile

import streamlit as st
import torch
from PIL import Image

from modules.model import Model
from modules.tokenizer import Tokenizer
# Browser-tab title and page layout for the whole app.
st.set_page_config(layout="centered", page_title="IMG-Captioning")
# Load the model at most once per (checkpoint, device) and keep it in memory
# across Streamlit reruns; st.cache_resource requires Streamlit 1.18+.
@st.cache_resource
def _load_model_cached(checkpoint_path: str, device_name: str):
    """Build and return the model for *checkpoint_path* on *device_name*.

    The device is taken as a string so that the cache key consists only of
    plain, hashable values.
    """
    # NOTE(review): the tokenizer is built from a placeholder vocabulary; this
    # presumably works because the checkpoint restores the real one — confirm.
    tokenizer = Tokenizer(["b", "a"])
    model = Model.load_from_checkpoint(
        path=checkpoint_path,
        tokenizer=tokenizer,
        freeze_backbone=True,
        device=torch.device(device_name),
    )
    model.eval()  # put the model in evaluation (inference) mode
    return model


def load_model(checkpoint_path: str, device: torch.device):
    """Return the model loaded from *checkpoint_path* on *device*.

    Streamlit re-executes the whole script on every widget interaction, so
    loading is delegated to a cached helper: repeated calls with the same
    path and device return the same model object instead of re-reading the
    checkpoint each time.

    Args:
        checkpoint_path: Path of the checkpoint file to load.
        device: Device the model should be placed on.

    Returns:
        The loaded model, already switched to eval mode.
    """
    return _load_model_cached(checkpoint_path, str(device))
def run_predict(model, img_path: str, maxlen: int = 256):
    """Run ``model.predict`` on the image at *img_path* with autograd disabled.

    Gradients are never needed for inference, so tracking is switched off
    here in case ``predict`` does not already do so itself.

    Args:
        model: Object exposing ``predict(path, maxlen=...)``.
        img_path: Path of the image file to run the model on.
        maxlen: Forwarded unchanged to ``model.predict``.

    Returns:
        Whatever ``model.predict`` returns.
    """
    no_grad_ctx = torch.no_grad()
    with no_grad_ctx:
        prediction = model.predict(img_path, maxlen=maxlen)
    return prediction
def _save_temp_image(image) -> str:
    """Save *image* as a JPEG in a new temporary file and return the file's path.

    The file is created with ``delete=False`` so it survives being closed;
    the caller is responsible for removing it.
    """
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        # Write through the already-open handle instead of reopening tmp.name:
        # reopening a still-open NamedTemporaryFile by name fails on Windows.
        image.save(tmp, format="JPEG")
        return tmp.name


def _render_result(result) -> None:
    """Display a prediction whose type is not known in advance.

    Dicts are shown as one ``key: value`` line per entry; strings, numbers,
    lists and tuples are passed to ``st.write`` as-is; anything else is shown
    through its ``repr``.
    """
    if isinstance(result, dict):
        for key, value in result.items():
            st.write(f"**{key}**: {value}")
    elif isinstance(result, (str, int, float, list, tuple)):
        st.write(result)
    else:
        st.write(repr(result))


def main():
    """Render the image-captioning page.

    Loads the model for the checkpoint path typed by the user, previews an
    uploaded image and, when **Predict** is pressed, runs the model on it and
    shows the output. Failures while loading the model, opening the image or
    predicting are reported on the page instead of aborting the script.
    """
    st.title("Görüntü -> Model Predict (Streamlit)")
    st.markdown("Upload a single image and press **Predict**. Model yüklü GPU yoksa CPU'da çalışır.")

    checkpoint_path = st.text_input("Checkpoint path", value="Trained_pt/best.pt")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.info(f"Model device: {device}")

    # A wrong path or a broken checkpoint is a user error: report it on the
    # page rather than letting the exception abort the whole script.
    try:
        model = load_model(checkpoint_path, device)
    except Exception as e:
        st.error(f"Model yüklenirken hata: {e}")
        return

    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "bmp"])
    maxlen = st.slider("maxlen", min_value=16, max_value=1024, value=256, step=16)
    if uploaded_file is None:
        return

    # Preview the upload. Pillow raises OSError (UnidentifiedImageError is a
    # subclass) for files it cannot decode.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as e:
        st.error(f"Görüntü açılamadı: {e}")
        return
    st.image(image, caption="Uploaded image", use_column_width=True)

    if not st.button("Predict"):
        return

    # The model is assumed to take a file path rather than an in-memory
    # image, so hand it a temporary copy of the upload.
    tmp_path = _save_temp_image(image)
    try:
        with st.spinner("Model çalıştırılıyor..."):
            result = run_predict(model, tmp_path, maxlen=maxlen)
    except Exception as e:
        st.error(f"Predict sırasında hata: {e}")
        return
    else:
        st.success("Predict tamamlandı.")
    finally:
        # delete=False means nothing else removes the file; do it here so each
        # prediction does not leave a file behind in the temp directory.
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; a leftover temp file is harmless

    st.subheader("Model çıktısı")
    _render_result(result)
if __name__ == "__main__":
    # Script entry point; `streamlit run` also executes the file as __main__.
    main()