Juna190825 committed
Commit 3c3fb04 · verified · 1 Parent(s): d8d0f11

Update Dockerfile

Files changed (1)
  app.py  +10  -14
app.py CHANGED
@@ -128,33 +128,29 @@ CACHE_DIR = "/cache/models"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 def load_model():
-    """Load model directly, attempting cache first"""
+    """Load model with automatic cache handling"""
     try:
-        # Try loading from cache
-        print("Attempting to load from cache...")
-        model = AutoModelForCausalLM.from_pretrained(
+        # First try with local files only (uses cache if available)
+        print("Checking for cached model...")
+        return AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
             cache_dir=CACHE_DIR,
-            local_files_only=True  # Force cache usage
-        ).to(DEVICE)
-        tokenizer = AutoTokenizer.from_pretrained(
+            local_files_only=True  # Will fail if not cached
+        ).to(DEVICE), AutoTokenizer.from_pretrained(
             MODEL_ID,
             cache_dir=CACHE_DIR,
             local_files_only=True
         )
     except OSError:
-        # Fallback to download if cache missing
-        print("Cache not found, downloading...")
-        model = AutoModelForCausalLM.from_pretrained(
+        # Fallback to download if not in cache
+        print("Downloading model...")
+        return AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
             cache_dir=CACHE_DIR
-        ).to(DEVICE)
+        ).to(DEVICE), AutoTokenizer.from_pretrained(
             MODEL_ID,
             cache_dir=CACHE_DIR
         )
-
-    return model, tokenizer
 
 # Load model
 model, tokenizer = load_model()
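
Note on the change: the refactored load_model() now returns the (model, tokenizer) pair directly from each branch instead of assigning intermediate variables, so the call site is unchanged. A minimal usage sketch of the returned pair follows; it assumes the model, tokenizer, and DEVICE objects defined in app.py, and the prompt text is illustrative only, not part of the app.

# Minimal usage sketch (assumes model, tokenizer, DEVICE from app.py above).
import torch

prompt = "Hello, world"  # illustrative prompt, not taken from app.py
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
with torch.no_grad():
    # Generate a short continuation from the cached (or freshly downloaded) model
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))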