Update app.py
Browse files
app.py
CHANGED
|
@@ -16,15 +16,18 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
| 16 |
logging.basicConfig(level=logging.INFO)
|
| 17 |
logging.info(f"torch version:\t{torch.__version__}")
|
| 18 |
|
| 19 |
-
|
| 20 |
-
checker_model_name = "textattack/roberta-base-CoLA"
|
| 21 |
-
corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"
|
| 22 |
-
|
| 23 |
|
| 24 |
checker = pipeline(
|
| 25 |
"text-classification",
|
| 26 |
-
checker_model_name,
|
| 27 |
-
device_map="cuda"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
)
|
| 29 |
|
| 30 |
corrector = pipeline(
|
|
|
|
| 16 |
logging.basicConfig(level=logging.INFO)
|
| 17 |
logging.info(f"torch version:\t{torch.__version__}")
|
| 18 |
|
| 19 |
+
device = 0 if torch.cuda.is_available() else -1
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
checker = pipeline(
|
| 22 |
"text-classification",
|
| 23 |
+
model=checker_model_name,
|
| 24 |
+
device=device, # แก้จาก device_map="cuda" เป็น device
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
corrector = pipeline(
|
| 28 |
+
"text2text-generation",
|
| 29 |
+
model=corrector_model_name,
|
| 30 |
+
device=device, # แก้จาก device_map="cuda" เป็น device
|
| 31 |
)
|
| 32 |
|
| 33 |
corrector = pipeline(
|