sentinel-api / app /ml /model_loader.py
Mustafa Öztürk
Use remote Detoxify loading in sentinel-api
6983cce
import gc
import torch
from detoxify import Detoxify
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from app.core.config import TR_OFF_MODEL_PATH
# Process-wide model cache, populated once by load_system(); every entry
# starts as None until loading completes (all keys are written together).
#   T_O          - BERTurk tokenizer
#   M_O          - BERTurk sequence-classification model
#   GB_PIPE      - gibberish-detector pipeline (stays None if it fails to load)
#   D_EN         - Detoxify "original" model
#   D_MULTI      - Detoxify "multilingual" model
#   TORCH_DEVICE - torch.device the models were loaded onto
_STATE = dict.fromkeys(
    (
        "T_O",
        "M_O",
        "GB_PIPE",
        "D_EN",
        "D_MULTI",
        "TORCH_DEVICE",
    )
)
def _quantize_int8(model, label):
    """Dynamically quantize ``model``'s ``torch.nn.Linear`` layers to INT8.

    Best-effort: if quantization raises, the error is logged and the original
    ``model`` is returned unchanged so loading can continue.

    Args:
        model: The ``torch.nn.Module`` to quantize.
        label: Human-readable model name used in the log lines.

    Returns:
        The quantized module (in eval mode), or ``model`` itself on failure.
    """
    try:
        quantized = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear},
            dtype=torch.qint8,
        )
    except Exception as e:  # deliberate best-effort: keep the float model
        print(f"[LOAD] {label} INT8 HATA: {e}")
        return model
    quantized.eval()
    print(f"[LOAD] {label} INT8 OK")
    return quantized


def load_system():
    """Load every model once and return the shared ``_STATE`` dict.

    Loads, in order:
    1. the BERTurk tokenizer/classifier from ``TR_OFF_MODEL_PATH``;
    2. the gibberish-detector pipeline (optional; stored as ``None`` if it
       fails to load);
    3. the Detoxify "original" and "multilingual" models.

    All models are placed on the same device (CUDA when available, else CPU).
    On CPU the BERTurk and Detoxify models are also dynamically quantized to
    INT8, best-effort.

    Later calls return the cached ``_STATE`` without reloading. The "already
    loaded" check is ``_STATE["T_O"] is not None``. All keys are written
    together at the end, so a load that raises leaves the cache empty and the
    next call retries.

    NOTE(review): there is no lock here; concurrent first calls may each
    load the models.

    Returns:
        The module-level ``_STATE`` dict with keys:
        - ``T_O``: tokenizer
        - ``M_O``: BERTurk model
        - ``GB_PIPE``: gibberish pipeline or ``None``
        - ``D_EN`` / ``D_MULTI``: Detoxify instances
        - ``TORCH_DEVICE``: the selected ``torch.device``

    Raises:
        Whatever the BERTurk or Detoxify loaders raise; only the gibberish
        pipeline and the INT8 steps are best-effort.
    """
    if _STATE["T_O"] is not None:
        return _STATE

    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # transformers.pipeline device convention: -1 = CPU, 0 = first GPU.
    device_id = 0 if torch_device.type == "cuda" else -1
    on_cpu = torch_device.type == "cpu"
    print(f"[LOAD] Device: {torch_device}")

    tokenizer_o = AutoTokenizer.from_pretrained(TR_OFF_MODEL_PATH)
    model_o = AutoModelForSequenceClassification.from_pretrained(TR_OFF_MODEL_PATH).to(torch_device)
    model_o.eval()
    print("[LOAD] BERTurk yüklendi")
    if on_cpu:
        model_o = _quantize_int8(model_o, "BERTurk")
        # Collect after rebinding so the float weights are actually unreferenced.
        gc.collect()

    try:
        gibberish = pipeline(
            "text-classification",
            model="madhurjindal/autonlp-Gibberish-Detector-492513457",
            device=device_id,
        )
        print("[LOAD] Gibberish yüklendi")
    except Exception as e:  # gibberish detection is optional
        print(f"[LOAD] Gibberish HATA: {e}")
        gibberish = None

    # Detoxify defaults to device="cpu"; pass the selected device so that on
    # CUDA hosts these models don't silently run on CPU without quantization.
    detox_en = Detoxify("original", device=torch_device)
    print("[LOAD] Detoxify EN yüklendi")
    detox_multi = Detoxify("multilingual", device=torch_device)
    print("[LOAD] Detoxify Multi yüklendi")
    if on_cpu:
        detox_en.model = _quantize_int8(detox_en.model, "Detoxify EN")
        gc.collect()
        detox_multi.model = _quantize_int8(detox_multi.model, "Detoxify Multi")
        gc.collect()

    _STATE.update(
        {
            "T_O": tokenizer_o,
            "M_O": model_o,
            "GB_PIPE": gibberish,
            "D_EN": detox_en,
            "D_MULTI": detox_multi,
            "TORCH_DEVICE": torch_device,
        }
    )
    print("[LOAD] Sistem hazir")
    return _STATE
def get_model_state():
    """Return the shared model-state dict, loading the models on first use.

    Thin accessor that delegates to ``load_system()``. Later calls get the
    same cached dict back.
    """
    state = load_system()
    return state