# NOTE(review): the three lines below were HuggingFace Spaces page residue
# ("Spaces: Sleeping Sleeping") picked up by extraction — kept as a comment
# so the module parses.
# Spaces: Sleeping
| import gc | |
| import torch | |
| from detoxify import Detoxify | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline | |
| from app.core.config import TR_OFF_MODEL_PATH | |
| _STATE = { | |
| "T_O": None, | |
| "M_O": None, | |
| "GB_PIPE": None, | |
| "D_EN": None, | |
| "D_MULTI": None, | |
| "TORCH_DEVICE": None, | |
| } | |
def _quantize_int8(model, label):
    """Best-effort dynamic INT8 quantization of a model's Linear layers.

    Args:
        model: torch.nn.Module to quantize (CPU only — caller gates on device).
        label: human-readable name used in the log lines (e.g. "BERTurk").

    Returns:
        The quantized model on success; the ORIGINAL model unchanged on any
        failure — quantization is an optimization, never a hard requirement.
    """
    try:
        quantized = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear},
            dtype=torch.qint8,
        )
        gc.collect()  # drop the fp32 weights that were just replaced
        print(f"[LOAD] {label} INT8 OK")
        return quantized
    except Exception as e:
        print(f"[LOAD] {label} INT8 HATA: {e}")
        return model


def load_system():
    """Lazily load all moderation models and cache them in _STATE.

    Idempotent: the first call loads the BERTurk offensive-language model,
    the gibberish-detection pipeline (best-effort; None on failure), and the
    two Detoxify models, applying dynamic INT8 quantization on CPU. Later
    calls return the already-populated registry immediately.

    Returns:
        dict: the populated _STATE registry.
    """
    # Already loaded? T_O is the sentinel — it is set together with the rest.
    if _STATE["T_O"] is not None:
        return _STATE

    # Check CUDA availability exactly once; both device representations
    # (int index for transformers pipeline, torch.device for .to()) derive
    # from the same answer so they can never disagree.
    use_cuda = torch.cuda.is_available()
    device_id = 0 if use_cuda else -1
    torch_device = torch.device("cuda" if use_cuda else "cpu")
    print(f"[LOAD] Device: {torch_device}")

    # --- BERTurk offensive-language classifier (local path from config) ---
    tokenizer_o = AutoTokenizer.from_pretrained(TR_OFF_MODEL_PATH)
    model_o = AutoModelForSequenceClassification.from_pretrained(TR_OFF_MODEL_PATH).to(torch_device)
    model_o.eval()
    print("[LOAD] BERTurk yüklendi")

    # Dynamic INT8 quantization only applies on CPU.
    if torch_device.type == "cpu":
        model_o = _quantize_int8(model_o, "BERTurk")
        model_o.eval()  # re-assert eval mode on the (possibly new) module

    # --- Gibberish detector: best-effort, service runs without it ---
    try:
        gibberish = pipeline(
            "text-classification",
            model="madhurjindal/autonlp-Gibberish-Detector-492513457",
            device=device_id,
        )
        print("[LOAD] Gibberish yüklendi")
    except Exception as e:
        print(f"[LOAD] Gibberish HATA: {e}")
        gibberish = None

    # --- Detoxify models (English + multilingual) ---
    detox_en = Detoxify("original")
    print("[LOAD] Detoxify EN yüklendi")
    detox_multi = Detoxify("multilingual")
    print("[LOAD] Detoxify Multi yüklendi")

    if torch_device.type == "cpu":
        detox_en.model = _quantize_int8(detox_en.model, "Detoxify EN")
        detox_multi.model = _quantize_int8(detox_multi.model, "Detoxify Multi")

    # Publish everything atomically-ish at the end so the T_O sentinel is
    # never set while the rest of the registry is still empty.
    _STATE.update(
        {
            "T_O": tokenizer_o,
            "M_O": model_o,
            "GB_PIPE": gibberish,
            "D_EN": detox_en,
            "D_MULTI": detox_multi,
            "TORCH_DEVICE": torch_device,
        }
    )
    print("[LOAD] Sistem hazir")
    return _STATE
def get_model_state():
    """Public accessor for the model registry.

    Simply delegates to load_system(), which loads everything on first use
    and returns the cached _STATE dict on every call after that.
    """
    return load_system()