Runtime error
Update Dockerfile
app.py CHANGED
@@ -128,33 +128,29 @@ CACHE_DIR = "/cache/models"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 def load_model():
-    """Load model
+    """Load model with automatic cache handling"""
     try:
-        #
-        print("
-        model = AutoModelForCausalLM.from_pretrained(
+        # First try with local files only (uses cache if available)
+        print("Checking for cached model...")
+        return AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
             cache_dir=CACHE_DIR,
-            local_files_only=True #
-        ).to(DEVICE)
-        tokenizer = AutoTokenizer.from_pretrained(
+            local_files_only=True  # Will fail if not cached
+        ).to(DEVICE), AutoTokenizer.from_pretrained(
             MODEL_ID,
             cache_dir=CACHE_DIR,
             local_files_only=True
         )
     except OSError:
-        # Fallback to download if cache
-        print("
-        model = AutoModelForCausalLM.from_pretrained(
+        # Fallback to download if not in cache
+        print("Downloading model...")
+        return AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
             cache_dir=CACHE_DIR
-        ).to(DEVICE)
-        tokenizer = AutoTokenizer.from_pretrained(
+        ).to(DEVICE), AutoTokenizer.from_pretrained(
             MODEL_ID,
             cache_dir=CACHE_DIR
         )
-
-    return model, tokenizer
 
 # Load model
 model, tokenizer = load_model()
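For reference, below is load_model() as it reads after this commit, assembled from the hunk into a self-contained, runnable sketch. The imports and the MODEL_ID value are assumptions: MODEL_ID is defined earlier in app.py, outside this hunk, so "gpt2" is a hypothetical stand-in. Everything else is copied from the new side of the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "gpt2"  # hypothetical placeholder; the real value is set earlier in app.py
CACHE_DIR = "/cache/models"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_model():
    """Load model with automatic cache handling"""
    try:
        # First try with local files only (uses cache if available)
        print("Checking for cached model...")
        return AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            cache_dir=CACHE_DIR,
            local_files_only=True  # Will fail if not cached
        ).to(DEVICE), AutoTokenizer.from_pretrained(
            MODEL_ID,
            cache_dir=CACHE_DIR,
            local_files_only=True
        )
    except OSError:
        # Fallback to download if not in cache
        print("Downloading model...")
        return AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            cache_dir=CACHE_DIR
        ).to(DEVICE), AutoTokenizer.from_pretrained(
            MODEL_ID,
            cache_dir=CACHE_DIR
        )

# Load model
model, tokenizer = load_model()

On the design: from_pretrained() raises OSError when local_files_only=True and the files are not present under cache_dir, which is what routes the first attempt into the download branch. Returning the (model, tokenizer) tuple directly drops the intermediate variables the old version used; it also means that if the model is cached but the tokenizer is not, the OSError from the tokenizer call sends both loads back through the download path.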