# AI_API/features/text_classifier/model_loader.py
# Author: Pujan-Dev
# Last commit: 0117df3 - "Added the file-adding system and formatted the file system" (1.49 kB)
import os
import shutil
import logging
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
from huggingface_hub import snapshot_download
import torch
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment at import time.
load_dotenv()

# Hugging Face repository that hosts the detector artifacts.
REPO_ID = "Pujan-Dev/AI-Text-Detector"
# Local mirror of the repository (relative to the current working directory).
MODEL_DIR = "./models"
# Directory that load_model() reads the tokenizer and GPT2Config from.
TOKENIZER_DIR = os.path.join(MODEL_DIR, "model")
# state_dict file that load_model() feeds into GPT2LMHeadModel.load_state_dict.
WEIGHTS_PATH = os.path.join(MODEL_DIR, "model_weights.pth")

# Run on the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Lazily-populated module-level cache for the loaded model and tokenizer.
_model = None
_tokenizer = None
def warmup():
    """Eagerly download (if needed) and load the model and tokenizer.

    After this returns, get_model_tokenizer() hands back the cached pair
    without touching disk. The work is delegated to get_model_tokenizer(),
    so calling warmup() more than once reuses the already-loaded model.
    It does not rebuild the model from disk, which previously kept two
    copies alive while the second load ran.
    """
    get_model_tokenizer()
    logging.info("Its ready")
def download_model_repo():
    """Mirror the Hugging Face repo REPO_ID into MODEL_DIR unless it is already there.

    The local copy counts as present only when both files that load_model()
    reads exist: the tokenizer/config directory (TOKENIZER_DIR) and the
    weights file (WEIGHTS_PATH). Checking for MODEL_DIR alone would skip the
    download forever in two cases: after an interrupted copy, or when an
    empty ``models`` directory happens to exist. load_model() would then fail.
    """
    if os.path.isdir(TOKENIZER_DIR) and os.path.isfile(WEIGHTS_PATH):
        logging.info("Model already exists, skipping download.")
        return
    snapshot_path = snapshot_download(repo_id=REPO_ID)
    # copytree creates MODEL_DIR itself. dirs_exist_ok=True lets it fill in
    # or overwrite a partial copy left behind by an earlier failed run.
    shutil.copytree(snapshot_path, MODEL_DIR, dirs_exist_ok=True)
def load_model():
    """Build the GPT-2 language model from local files.

    Reads the tokenizer and GPT2Config from TOKENIZER_DIR, instantiates
    GPT2LMHeadModel from that config and loads the state_dict stored at
    WEIGHTS_PATH. It then moves the model to ``device`` and switches it to
    eval mode.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    tokenizer = GPT2TokenizerFast.from_pretrained(TOKENIZER_DIR)
    config = GPT2Config.from_pretrained(TOKENIZER_DIR)
    model = GPT2LMHeadModel(config)
    # Deserialize onto the CPU, where the freshly constructed model's
    # parameters live. Mapping straight to `device` would stage a full extra
    # copy on the GPU, only for it to be copied back into the CPU parameters
    # and then moved to the GPU again by model.to(device).
    # SECURITY NOTE(review): torch.load unpickles the file, and these weights
    # come from a remote repository. torch releases before 2.6 default to
    # weights_only=False, which lets a tampered checkpoint execute arbitrary
    # code. Consider passing weights_only=True once the torch version is
    # confirmed.
    state_dict = torch.load(WEIGHTS_PATH, map_location="cpu")
    model.load_state_dict(state_dict)
    # load_state_dict copied the tensors into the model; drop the duplicate.
    del state_dict
    model.to(device)
    model.eval()
    return model, tokenizer
def get_model_tokenizer():
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use.

    If either cached value is missing, the repository is downloaded as needed
    and both are rebuilt via load_model(). There is no locking, so concurrent
    first calls may each perform the load.
    """
    global _model, _tokenizer
    if _model is not None and _tokenizer is not None:
        return _model, _tokenizer
    download_model_repo()
    _model, _tokenizer = load_model()
    return _model, _tokenizer