File size: 722 Bytes
9f452a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import os
import requests
import torch
import streamlit as st

def get_device():
    """Return the torch device to run on: CUDA when available, else CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)

def download_checkpoint(url, filename):
    """Download ``url`` to ``filename`` unless that file already exists.

    Progress and status are reported through Streamlit widgets. The data is
    streamed to a temporary ``<filename>.part`` file and only renamed to
    ``filename`` once the transfer finished, so an interrupted download is
    never mistaken for a complete checkpoint on the next call.

    Args:
        url: HTTP(S) address of the checkpoint to fetch.
        filename: Destination path of the checkpoint on disk.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
        OSError: If the destination cannot be written.
    """
    if os.path.exists(filename):
        st.info(f"Checkpoint già presente: {filename}")
        return

    st.warning(f"Scaricamento checkpoint SAM in corso... ({filename})")
    partial = f"{filename}.part"
    bar = st.progress(0)
    try:
        # The context manager releases the connection even on failure; the
        # timeout keeps a stalled server from hanging the app forever.
        with requests.get(url, stream=True, timeout=30) as resp:
            # Fail loudly instead of saving an HTML error page as a checkpoint.
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", 0))
            dl = 0
            with open(partial, "wb") as f:
                for chunk in resp.iter_content(1_048_576):
                    if not chunk:
                        continue
                    f.write(chunk)
                    dl += len(chunk)
                    # Content-Length may be absent (total == 0) or smaller
                    # than the decoded body; avoid dividing by zero and keep
                    # the value within the 0-100 range st.progress accepts.
                    if total > 0:
                        bar.progress(min(int(dl / total * 100), 100))
        # Atomic rename: ``filename`` only ever exists fully written.
        os.replace(partial, filename)
    except BaseException:
        # Drop the incomplete file so a later call retries the download.
        if os.path.exists(partial):
            os.remove(partial)
        raise
    bar.progress(100)
    st.success("Download completato!")