File size: 2,801 Bytes
c399765
 
857d4f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be2ed92
857d4f5
 
398cb92
857d4f5
be2ed92
857d4f5
7a29d91
 
 
 
 
 
 
 
c399765
be2ed92
 
 
7a29d91
857d4f5
 
 
 
 
 
be2ed92
 
 
398cb92
857d4f5
6983cce
be2ed92
6983cce
be2ed92
857d4f5
7a29d91
 
 
 
 
 
 
c399765
be2ed92
 
 
7a29d91
 
 
 
 
 
c399765
be2ed92
 
 
7a29d91
857d4f5
 
 
 
 
 
 
 
 
 
be2ed92
857d4f5
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gc

import torch
from detoxify import Detoxify
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

from app.core.config import TR_OFF_MODEL_PATH

# Lazily-filled model cache returned by load_system() / get_model_state().
# Every slot starts as None; load_system() fills them all in a single update
# once loading has finished.
#   T_O          - BERTurk tokenizer (AutoTokenizer from TR_OFF_MODEL_PATH)
#   M_O          - BERTurk sequence-classification model
#   GB_PIPE      - gibberish-detector text-classification pipeline
#                  (stays None if that pipeline failed to load)
#   D_EN         - Detoxify "original" model
#   D_MULTI      - Detoxify "multilingual" model
#   TORCH_DEVICE - torch.device chosen at load time (cuda if available, else cpu)
_STATE = dict.fromkeys(
    (
        "T_O",
        "M_O",
        "GB_PIPE",
        "D_EN",
        "D_MULTI",
        "TORCH_DEVICE",
    )
)


def _quantize_int8(model, label: str):
    """Dynamically quantize the ``torch.nn.Linear`` layers of *model* to INT8.

    Best-effort: any failure is printed and the original, unquantized *model*
    is returned so loading can continue.

    Args:
        model: The module to quantize. ``quantize_dynamic`` is called with its
            default ``inplace=False``, so it works on a copy.
        label: Model name used in the ``[LOAD]`` log lines.

    Returns:
        The quantized copy (in eval mode), or *model* itself on failure.
    """
    try:
        quantized = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear},
            dtype=torch.qint8,
        )
    except Exception as e:
        print(f"[LOAD] {label} INT8 HATA: {e}")
        return model
    quantized.eval()
    print(f"[LOAD] {label} INT8 OK")
    return quantized


def load_system() -> dict:
    """Load every moderation model once and cache it in the module-level ``_STATE``.

    The first call loads, in order:

    - the BERTurk tokenizer and sequence classifier from ``TR_OFF_MODEL_PATH``
      (``T_O`` / ``M_O``);
    - the gibberish-detector pipeline (``GB_PIPE``);
    - the English and multilingual Detoxify models (``D_EN`` / ``D_MULTI``).

    All models use CUDA when it is available, otherwise the CPU. On CPU, their
    ``torch.nn.Linear`` layers are additionally dynamically quantized to INT8
    (best-effort). Later calls return the cached state without reloading.

    ``_STATE`` is only updated after everything has loaded. If any required
    load raises, the cache stays empty and the next call retries.

    NOTE(review): there is no lock around the check-then-load. Two threads
    calling this concurrently before the cache is filled would both load the
    models; confirm callers trigger the first call once (e.g. at startup).

    Returns:
        The shared ``_STATE`` dict. ``GB_PIPE`` is ``None`` if the gibberish
        pipeline failed to load; every other key is populated.

    Raises:
        Whatever the tokenizer, BERTurk model or Detoxify loaders raise. Only a
        gibberish-pipeline failure is swallowed.
    """
    if _STATE["T_O"] is not None:
        return _STATE

    use_cuda = torch.cuda.is_available()
    torch_device = torch.device("cuda" if use_cuda else "cpu")
    # transformers.pipeline takes an integer device index: 0 = first GPU, -1 = CPU
    device_id = 0 if use_cuda else -1
    # Dynamic INT8 quantization is only applied for CPU inference
    quantize = torch_device.type == "cpu"
    print(f"[LOAD] Device: {torch_device}")

    tokenizer_o = AutoTokenizer.from_pretrained(TR_OFF_MODEL_PATH)
    model_o = AutoModelForSequenceClassification.from_pretrained(TR_OFF_MODEL_PATH).to(torch_device)
    model_o.eval()
    print("[LOAD] BERTurk yüklendi")

    if quantize:
        model_o = _quantize_int8(model_o, "BERTurk")
        # Collect the float model (if it was replaced) now that nothing references it
        gc.collect()

    try:
        gibberish = pipeline(
            "text-classification",
            model="madhurjindal/autonlp-Gibberish-Detector-492513457",
            device=device_id,
        )
        print("[LOAD] Gibberish yüklendi")
    except Exception as e:
        # Optional model: continue without it; callers see GB_PIPE = None
        print(f"[LOAD] Gibberish HATA: {e}")
        gibberish = None

    # Detoxify defaults to device="cpu". Pass the selected device explicitly
    # so that, on CUDA hosts, these models do not silently stay on the CPU in
    # float32 (the INT8 step below is gated on the selected device as well).
    detox_en = Detoxify("original", device=torch_device.type)
    print("[LOAD] Detoxify EN yüklendi")
    detox_multi = Detoxify("multilingual", device=torch_device.type)
    print("[LOAD] Detoxify Multi yüklendi")

    if quantize:
        for detox, label in ((detox_en, "Detoxify EN"), (detox_multi, "Detoxify Multi")):
            detox.model = _quantize_int8(detox.model, label)
            # Collect the replaced float model now that the attribute was rebound
            gc.collect()

    _STATE.update(
        {
            "T_O": tokenizer_o,
            "M_O": model_o,
            "GB_PIPE": gibberish,
            "D_EN": detox_en,
            "D_MULTI": detox_multi,
            "TORCH_DEVICE": torch_device,
        }
    )
    print("[LOAD] Sistem hazir")

    return _STATE


def get_model_state():
    """Return the shared model-state dict, loading the models on first use.

    Alias for ``load_system()``. The first call performs the model loading;
    subsequent calls get the already-populated ``_STATE`` dict back.
    """
    state = load_system()
    return state