BoostedJonP committed on
Commit
9c0c216
·
1 Parent(s): 6f78921

auto config

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from functools import lru_cache
5
  import logging
6
 
@@ -20,6 +20,12 @@ def load_model():
20
  logger.info(f"Loading model: {MODEL_NAME}")
21
 
22
  try:
 
 
 
 
 
 
23
  tokenizer = AutoTokenizer.from_pretrained(
24
  MODEL_NAME,
25
  trust_remote_code=True,
@@ -31,6 +37,7 @@ def load_model():
31
  logger.info("CUDA available, loading with GPU optimizations")
32
  model = AutoModelForCausalLM.from_pretrained(
33
  MODEL_NAME,
 
34
  trust_remote_code=True,
35
  torch_dtype=torch.float16,
36
  device_map="auto",
@@ -40,8 +47,12 @@ def load_model():
40
  )
41
  else:
42
  logger.info("CUDA not available, loading with CPU optimizations")
 
 
 
43
  model = AutoModelForCausalLM.from_pretrained(
44
  MODEL_NAME,
 
45
  trust_remote_code=True,
46
  torch_dtype=torch.float32,
47
  attn_implementation="eager",
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
4
  from functools import lru_cache
5
  import logging
6
 
 
20
  logger.info(f"Loading model: {MODEL_NAME}")
21
 
22
  try:
23
+ config = AutoConfig.from_pretrained(
24
+ MODEL_NAME,
25
+ trust_remote_code=True,
26
+ cache_dir="/tmp/model_cache",
27
+ )
28
+
29
  tokenizer = AutoTokenizer.from_pretrained(
30
  MODEL_NAME,
31
  trust_remote_code=True,
 
37
  logger.info("CUDA available, loading with GPU optimizations")
38
  model = AutoModelForCausalLM.from_pretrained(
39
  MODEL_NAME,
40
+ config=config,
41
  trust_remote_code=True,
42
  torch_dtype=torch.float16,
43
  device_map="auto",
 
47
  )
48
  else:
49
  logger.info("CUDA not available, loading with CPU optimizations")
50
+ if getattr(config, "quantization_config", None) is not None:
51
+ logger.info("Disabling quantization settings for CPU execution")
52
+ config.quantization_config = None
53
  model = AutoModelForCausalLM.from_pretrained(
54
  MODEL_NAME,
55
+ config=config,
56
  trust_remote_code=True,
57
  torch_dtype=torch.float32,
58
  attn_implementation="eager",