9Dome commited on
Commit
84e7748
·
verified ·
1 Parent(s): 86bbcdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -16,15 +16,18 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
16
  logging.basicConfig(level=logging.INFO)
17
  logging.info(f"torch version:\t{torch.__version__}")
18
 
19
- # Model names
20
- checker_model_name = "textattack/roberta-base-CoLA"
21
- corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"
22
-
23
 
24
  checker = pipeline(
25
  "text-classification",
26
- checker_model_name,
27
- device_map="cuda",
 
 
 
 
 
 
28
  )
29
 
30
  corrector = pipeline(
 
16
  logging.basicConfig(level=logging.INFO)
17
  logging.info(f"torch version:\t{torch.__version__}")
18
 
19
+ device = 0 if torch.cuda.is_available() else -1
 
 
 
20
 
21
  checker = pipeline(
22
  "text-classification",
23
+ model=checker_model_name,
24
+ device=device, # แก้จาก device_map="cuda" เป็น device
25
+ )
26
+
27
+ corrector = pipeline(
28
+ "text2text-generation",
29
+ model=corrector_model_name,
30
+ device=device, # แก้จาก device_map="cuda" เป็น device
31
  )
32
 
33
  corrector = pipeline(