entropy25 committed on
Commit
1f33fbe
·
verified ·
1 Parent(s): 4d25cd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -31,9 +31,12 @@ MAX_LENGTH = 256
31
 
32
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
33
 
 
 
34
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
35
  BASE_MODEL,
36
- quantization_config=BitsAndBytesConfig(load_in_8bit=True),
 
37
  low_cpu_mem_usage=True,
38
  )
39
 
@@ -91,7 +94,7 @@ def translate_cached(text, source_lang, target_lang):
91
 
92
  adapter_name, tgt_code = config
93
  start = time.time()
94
- device = next(model.parameters()).device
95
  translated_paragraphs = []
96
 
97
  for paragraph in text.split("\n"):
@@ -111,7 +114,7 @@ def translate_cached(text, source_lang, target_lang):
111
  truncation=True,
112
  max_length=MAX_LENGTH,
113
  )
114
- inputs = {k: v.to(device) for k, v in inputs.items()}
115
 
116
  with adapter_lock:
117
  model.set_adapter(adapter_name)
 
31
 
32
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
33
 
34
+ device = "cuda" if torch.cuda.is_available() else "cpu"
35
+
36
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
37
  BASE_MODEL,
38
+ quantization_config=BitsAndBytesConfig(load_in_8bit=True) if device == "cuda" else None,
39
+ device_map={"": 0} if device == "cuda" else None,
40
  low_cpu_mem_usage=True,
41
  )
42
 
 
94
 
95
  adapter_name, tgt_code = config
96
  start = time.time()
97
+ dev = next(model.parameters()).device
98
  translated_paragraphs = []
99
 
100
  for paragraph in text.split("\n"):
 
114
  truncation=True,
115
  max_length=MAX_LENGTH,
116
  )
117
+ inputs = {k: v.to(dev) for k, v in inputs.items()}
118
 
119
  with adapter_lock:
120
  model.set_adapter(adapter_name)