manarsaber11 commited on
Commit
236505b
·
verified ·
1 Parent(s): b716ea0

Update unified_api.py

Browse files
Files changed (1) hide show
  1. unified_api.py +8 -7
unified_api.py CHANGED
@@ -24,7 +24,6 @@ from fastapi.responses import JSONResponse
24
  from fastapi.middleware.cors import CORSMiddleware
25
  from pydantic import BaseModel
26
  from transformers import AutoTokenizer, RobertaModel
27
- from huggingface_hub import hf_hub_download
28
  import torch.nn.functional as F
29
 
30
 
@@ -477,14 +476,16 @@ class DocumentProcessor:
477
  async def lifespan(app: FastAPI):
478
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
479
  logger.info(f"[startup] device = {device}")
480
- logger.info(f"[startup] downloading model from HuggingFace: {Config.HF_REPO_ID}")
481
 
482
- # Download files from HF Hub
483
- model_path = hf_hub_download(repo_id=Config.HF_REPO_ID, filename="model_last.pt")
484
- le_dept_path = hf_hub_download(repo_id=Config.HF_REPO_ID, filename="label_encoder.pkl")
485
- le_prio_path = hf_hub_download(repo_id=Config.HF_REPO_ID, filename="priority_encoder.pkl")
486
 
487
- tokenizer = AutoTokenizer.from_pretrained(Config.HF_REPO_ID)
 
 
 
 
 
 
488
  le_dept = joblib.load(le_dept_path)
489
  le_prio = joblib.load(le_prio_path)
490
 
 
24
  from fastapi.middleware.cors import CORSMiddleware
25
  from pydantic import BaseModel
26
  from transformers import AutoTokenizer, RobertaModel
 
27
  import torch.nn.functional as F
28
 
29
 
 
476
  async def lifespan(app: FastAPI):
477
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
478
  logger.info(f"[startup] device = {device}")
 
479
 
480
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 
481
 
482
+ model_path = os.path.join(BASE_DIR, "model_last.pt")
483
+ le_dept_path = os.path.join(BASE_DIR, "label_encoder.pkl")
484
+ le_prio_path = os.path.join(BASE_DIR, "priority_encoder.pkl")
485
+
486
+ logger.info(f"[startup] loading model from local files...")
487
+
488
+ tokenizer = AutoTokenizer.from_pretrained(BASE_DIR)
489
  le_dept = joblib.load(le_dept_path)
490
  le_prio = joblib.load(le_prio_path)
491