BoostedJonP committed on
Commit
5119833
·
1 Parent(s): 8b45180

additional cpu fallback

Browse files
Files changed (2) hide show
  1. app.py +26 -10
  2. requirements.txt +0 -1
app.py CHANGED
@@ -40,16 +40,32 @@ def load_model():
40
  )
41
  else:
42
  logger.info("CUDA not available, loading with CPU optimizations")
43
- model = AutoModelForCausalLM.from_pretrained(
44
- MODEL_NAME,
45
- trust_remote_code=True,
46
- torch_dtype=torch.float32, # Use float32 for CPU
47
- device_map="cpu", # Explicitly set to CPU
48
- attn_implementation="eager",
49
- use_cache=True,
50
- cache_dir="/tmp/model_cache",
51
- low_cpu_mem_usage=True, # Helpful for CPU environments
52
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  logger.info("Model loaded successfully!")
55
  except Exception as e:
 
40
  )
41
  else:
42
  logger.info("CUDA not available, loading with CPU optimizations")
43
+ try:
44
+ model = AutoModelForCausalLM.from_pretrained(
45
+ MODEL_NAME,
46
+ trust_remote_code=True,
47
+ torch_dtype=torch.float32, # Use float32 for CPU
48
+ device_map="cpu", # Explicitly set to CPU
49
+ attn_implementation="eager",
50
+ use_cache=True,
51
+ cache_dir="/tmp/model_cache",
52
+ low_cpu_mem_usage=True, # Helpful for CPU environments
53
+ )
54
+ except Exception as cpu_error:
55
+ logger.warning(f"CPU loading failed with device_map: {cpu_error}")
56
+ # Fallback: try without device_map
57
+ logger.info("Trying fallback CPU loading without device_map")
58
+ model = AutoModelForCausalLM.from_pretrained(
59
+ MODEL_NAME,
60
+ trust_remote_code=True,
61
+ torch_dtype=torch.float32,
62
+ attn_implementation="eager",
63
+ use_cache=True,
64
+ cache_dir="/tmp/model_cache",
65
+ low_cpu_mem_usage=True,
66
+ )
67
+ # Move model to CPU manually
68
+ model = model.to("cpu")
69
 
70
  logger.info("Model loaded successfully!")
71
  except Exception as e:
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
  torch>=2.0.0,<2.3.0
2
  transformers==4.48.0
3
  accelerate>=0.20.0
4
- bitsandbytes>=0.41.0
5
  gradio>=4.0.0,<5.0.0
6
  safetensors>=0.4.0
 
1
  torch>=2.0.0,<2.3.0
2
  transformers==4.48.0
3
  accelerate>=0.20.0
 
4
  gradio>=4.0.0,<5.0.0
5
  safetensors>=0.4.0