Spaces:
Sleeping
Sleeping
Update unified_api.py
Browse files- unified_api.py +8 -7
unified_api.py
CHANGED
|
@@ -24,7 +24,6 @@ from fastapi.responses import JSONResponse
|
|
| 24 |
from fastapi.middleware.cors import CORSMiddleware
|
| 25 |
from pydantic import BaseModel
|
| 26 |
from transformers import AutoTokenizer, RobertaModel
|
| 27 |
-
from huggingface_hub import hf_hub_download
|
| 28 |
import torch.nn.functional as F
|
| 29 |
|
| 30 |
|
|
@@ -477,14 +476,16 @@ class DocumentProcessor:
|
|
| 477 |
async def lifespan(app: FastAPI):
|
| 478 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 479 |
logger.info(f"[startup] device = {device}")
|
| 480 |
-
logger.info(f"[startup] downloading model from HuggingFace: {Config.HF_REPO_ID}")
|
| 481 |
|
| 482 |
-
|
| 483 |
-
model_path = hf_hub_download(repo_id=Config.HF_REPO_ID, filename="model_last.pt")
|
| 484 |
-
le_dept_path = hf_hub_download(repo_id=Config.HF_REPO_ID, filename="label_encoder.pkl")
|
| 485 |
-
le_prio_path = hf_hub_download(repo_id=Config.HF_REPO_ID, filename="priority_encoder.pkl")
|
| 486 |
|
| 487 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
le_dept = joblib.load(le_dept_path)
|
| 489 |
le_prio = joblib.load(le_prio_path)
|
| 490 |
|
|
|
|
| 24 |
from fastapi.middleware.cors import CORSMiddleware
|
| 25 |
from pydantic import BaseModel
|
| 26 |
from transformers import AutoTokenizer, RobertaModel
|
|
|
|
| 27 |
import torch.nn.functional as F
|
| 28 |
|
| 29 |
|
|
|
|
| 476 |
async def lifespan(app: FastAPI):
|
| 477 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 478 |
logger.info(f"[startup] device = {device}")
|
|
|
|
| 479 |
|
| 480 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
|
|
|
|
|
| 481 |
|
| 482 |
+
model_path = os.path.join(BASE_DIR, "model_last.pt")
|
| 483 |
+
le_dept_path = os.path.join(BASE_DIR, "label_encoder.pkl")
|
| 484 |
+
le_prio_path = os.path.join(BASE_DIR, "priority_encoder.pkl")
|
| 485 |
+
|
| 486 |
+
logger.info(f"[startup] loading model from local files...")
|
| 487 |
+
|
| 488 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_DIR)
|
| 489 |
le_dept = joblib.load(le_dept_path)
|
| 490 |
le_prio = joblib.load(le_prio_path)
|
| 491 |
|