Spaces:
Paused
Paused
Update app/models.py
Browse files- app/models.py +3 -3
app/models.py
CHANGED
|
@@ -3,7 +3,7 @@ import logging
|
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Optional
|
| 5 |
from pydantic import BaseModel
|
| 6 |
-
from transformers import AutoTokenizer,
|
| 7 |
|
| 8 |
import torch
|
| 9 |
|
|
@@ -38,7 +38,7 @@ class DataLocation(BaseModel):
|
|
| 38 |
self.cloud_uri, token=AUTH_TOKEN
|
| 39 |
)
|
| 40 |
|
| 41 |
-
model =
|
| 42 |
self.cloud_uri, token=AUTH_TOKEN
|
| 43 |
)
|
| 44 |
# Save the model and tokenizer locally
|
|
@@ -73,7 +73,7 @@ class QAModel:
|
|
| 73 |
|
| 74 |
# Load the tokenizer and model from the local path
|
| 75 |
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 76 |
-
self.model =
|
| 77 |
logger.info(f"Loaded QA model: {self.model_name}")
|
| 78 |
|
| 79 |
def inference_qa(self, context: str, question: str):
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Optional
|
| 5 |
from pydantic import BaseModel
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
|
| 7 |
|
| 8 |
import torch
|
| 9 |
|
|
|
|
| 38 |
self.cloud_uri, token=AUTH_TOKEN
|
| 39 |
)
|
| 40 |
|
| 41 |
+
model = AutoModelForQuestionAnswering.from_pretrained(
|
| 42 |
self.cloud_uri, token=AUTH_TOKEN
|
| 43 |
)
|
| 44 |
# Save the model and tokenizer locally
|
|
|
|
| 73 |
|
| 74 |
# Load the tokenizer and model from the local path
|
| 75 |
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 76 |
+
self.model = AutoModelForQuestionAnswering.from_pretrained(model_path,return_dict=True).to(self.device)
|
| 77 |
logger.info(f"Loaded QA model: {self.model_name}")
|
| 78 |
|
| 79 |
def inference_qa(self, context: str, question: str):
|