h7-chat-backend / models_loader.py
hashan-7's picture
Update code
90fdd28 verified
import pickle
from typing import Optional, Tuple
import numpy as np
import requests
import keras
from sentence_transformers import CrossEncoder
# Relative paths of the serialized intent-classification artifacts.
# The two .pkl files are read with pickle.load; the .keras file is read with
# keras.models.load_model.
TOKENIZER_PATH = "ai_engine/tokenizer.pkl"
LABEL_ENCODER_PATH = "ai_engine/label_encoder.pkl"
INTENT_MODEL_PATH = "ai_engine/h7_intent_model.keras"
# Model identifier handed to sentence_transformers.CrossEncoder.
RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# Model name sent as the "name" field of the JSON body POSTed to OLLAMA_PULL_URL.
MODEL_NAME = "llama3.2:1b"
OLLAMA_PULL_URL = "http://localhost:11434/api/pull"
# Value passed as the `timeout` argument of requests.post (seconds).
OLLAMA_WARMUP_TIMEOUT = 2
# Module-level model handles: assigned by load_all_models() and set back to
# None by reset_models().
tokenizer = None
label_encoder = None
intent_model = None
re_ranker = None
# Set to True by load_all_models() after a successful load; False otherwise.
models_loaded = False
def patch_layer(layer_class):
    """Build a replacement ``__init__`` for *layer_class* that drops ``quantization_config``.

    The returned function removes a ``quantization_config`` keyword argument
    (if present) and forwards every other argument to the class's current
    ``__init__``. The caller is responsible for assigning the result back to
    ``layer_class.__init__``; this function does not modify the class.

    NOTE(review): presumably needed because saved layer configs may carry a
    ``quantization_config`` entry that the installed Keras constructors
    reject -- confirm against the Keras versions in use.

    Args:
        layer_class: Class whose ``__init__`` should be wrapped.

    Returns:
        The wrapping ``__init__``. If ``layer_class.__init__`` is already a
        wrapper produced by this function, that same wrapper is returned so
        repeated patching does not stack wrappers.
    """
    import functools  # local import keeps the block self-contained

    original_init = layer_class.__init__

    # Already patched on an earlier call: reuse it instead of nesting again.
    if getattr(original_init, "_drops_quantization_config", False):
        return original_init

    @functools.wraps(original_init)  # keep the original name/docstring visible
    def patched_init(self, *args, **kwargs):
        kwargs.pop("quantization_config", None)
        original_init(self, *args, **kwargs)

    # Marker read by the idempotence check above.
    patched_init._drops_quantization_config = True
    return patched_init
def apply_keras_patches():
    """Replace ``__init__`` on selected Keras layer classes with the ``patch_layer`` wrapper.

    The classes patched are ``Embedding``, ``Dense``, ``LSTM`` and
    ``Bidirectional`` from ``keras.layers``. Patching is best-effort: a
    failure on one class is reported on stdout and the remaining classes are
    still processed.
    """
    target_layers = (
        keras.layers.Embedding,
        keras.layers.Dense,
        keras.layers.LSTM,
        keras.layers.Bidirectional,
    )
    for layer in target_layers:
        try:
            layer.__init__ = patch_layer(layer)
        except Exception as patch_error:
            # Do not abort start-up over a failed patch, but make it visible
            # instead of swallowing it silently.
            layer_name = getattr(layer, "__name__", repr(layer))
            print(
                f"⚠️ Keras patch warning ({layer_name}): {patch_error}",
                flush=True,
            )
def warmup_ollama_model():
    """Send a best-effort POST to ``OLLAMA_PULL_URL`` naming ``MODEL_NAME``.

    The request carries ``{"name": MODEL_NAME}`` as JSON and uses
    ``OLLAMA_WARMUP_TIMEOUT`` as its timeout. The response is not inspected.
    Any exception is reported on stdout and otherwise ignored, so this
    function never raises on a failed request.

    NOTE(review): presumably this asks a local Ollama server to make the
    model available ahead of the first chat request -- confirm against the
    Ollama API.
    """
    try:
        requests.post(
            OLLAMA_PULL_URL,
            json={"name": MODEL_NAME},
            timeout=OLLAMA_WARMUP_TIMEOUT,
        )
    except Exception as warmup_error:
        # Warm-up is optional; surface the failure without interrupting start-up.
        print(f"⚠️ Ollama warm-up warning: {warmup_error}", flush=True)
def load_tokenizer():
    """Read the pickled object stored at ``TOKENIZER_PATH`` and return it.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserializing, so the file must come from a trusted source.
    """
    with open(TOKENIZER_PATH, mode="rb") as pickle_file:
        loaded_tokenizer = pickle.load(pickle_file)
    return loaded_tokenizer
def load_label_encoder():
    """Read the pickled object stored at ``LABEL_ENCODER_PATH`` and return it.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserializing, so the file must come from a trusted source.
    """
    with open(LABEL_ENCODER_PATH, mode="rb") as pickle_file:
        loaded_encoder = pickle.load(pickle_file)
    return loaded_encoder
def load_intent_model():
    """Load the model saved at ``INTENT_MODEL_PATH`` via ``keras.models.load_model``."""
    model = keras.models.load_model(INTENT_MODEL_PATH)
    return model
def load_reranker():
    """Construct and return a ``CrossEncoder`` for ``RERANKER_MODEL_NAME``."""
    cross_encoder = CrossEncoder(RERANKER_MODEL_NAME)
    return cross_encoder
def reset_models():
    """Clear every module-level model handle and mark the models as not loaded."""
    global tokenizer, label_encoder, intent_model, re_ranker, models_loaded
    models_loaded = False
    tokenizer = label_encoder = intent_model = re_ranker = None
def _load_reranker_or_none():
    """Return the result of ``load_reranker()``, or ``None`` if it raises.

    A failure is printed as a warning instead of being propagated.
    """
    try:
        return load_reranker()
    except Exception as rerank_failure:
        print(f"⚠️ Re-ranker load warning: {rerank_failure}", flush=True)
        return None


def load_all_models(force_reload: bool = False):
    """Populate the module-level model handles.

    Does nothing except print a notice when the models are already loaded
    and ``force_reload`` is false. Otherwise it applies the Keras patches,
    loads the tokenizer, label encoder and intent model, loads the re-ranker
    (a re-ranker failure only produces a warning and leaves ``re_ranker`` as
    ``None``), calls the Ollama warm-up and sets ``models_loaded``.

    Any other exception during loading is printed and all handles are
    cleared through ``reset_models()``; nothing is raised to the caller.

    Args:
        force_reload: Reload even if ``models_loaded`` is already true.
    """
    global tokenizer, label_encoder, intent_model, re_ranker, models_loaded

    already_loaded = models_loaded and not force_reload
    if already_loaded:
        print("✅ Models already loaded.", flush=True)
        return

    apply_keras_patches()

    try:
        tokenizer = load_tokenizer()
        label_encoder = load_label_encoder()
        intent_model = load_intent_model()
        # The re-ranker is optional: its failure must not abort the load.
        re_ranker = _load_reranker_or_none()
        warmup_ollama_model()
        models_loaded = True
        print("✅ Models Loaded.", flush=True)
    except Exception as startup_error:
        reset_models()
        print(f"❌ Startup Error: {startup_error}", flush=True)
def are_models_ready() -> bool:
    """Return True when the tokenizer, label encoder and intent model are all set.

    The re-ranker is not part of this check.
    """
    required = (tokenizer, label_encoder, intent_model)
    return all(handle is not None for handle in required)
def predict_intent(text: str, max_sequence_length: int = 20) -> Tuple[str, float]:
    """Classify *text* with the loaded intent model.

    Args:
        text: Input to classify; it is converted with ``str`` and stripped.
        max_sequence_length: ``maxlen`` passed to ``pad_sequences``.

    Returns:
        ``(label, confidence)``, where ``confidence`` is the maximum value of
        the model output and ``label`` is the label encoder's inverse
        transform of the arg-max index. ``("unknown", 0.0)`` is returned when
        the models are not ready, when the text is empty after stripping, or
        when any exception occurs (the exception is printed, not raised).
    """
    fallback = ("unknown", 0.0)

    if not are_models_ready():
        return fallback

    try:
        from keras.preprocessing.sequence import pad_sequences

        cleaned = str(text or "").strip()
        if cleaned == "":
            return fallback

        token_ids = tokenizer.texts_to_sequences([cleaned])
        model_input = pad_sequences(token_ids, maxlen=max_sequence_length)
        scores = intent_model.predict(model_input, verbose=0)

        confidence = float(np.max(scores))
        best_index = int(np.argmax(scores))
        decoded = label_encoder.inverse_transform([best_index])
        return str(decoded[0]), confidence
    except Exception as prediction_error:
        print(f"Intent Prediction Error: {prediction_error}", flush=True)
        return fallback