# bilstm / app.py
# Uploaded by neyugntuan ("Upload app.py"), commit 37104be (verified).
# -*- coding: utf-8 -*-
"""app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1FHgZJGBHIJ2C2OvYebeuCqVw1F7Rc29k
"""
import pickle
import re
import unicodedata
from pathlib import Path

import tensorflow as tf
from fastapi import FastAPI
from pydantic import BaseModel
from tensorflow.keras.layers import InputLayer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from underthesea import word_tokenize
app = FastAPI()

# ---- PATCH InputLayer ----
# Compatibility shim for loading the saved model: its serialized InputLayer
# config is expected to carry the input shape under the key "batch_shape",
# while the installed InputLayer constructor takes "batch_input_shape"
# (presumably the model was saved with a newer Keras than the one serving it
# -- TODO confirm against the training environment). Wrapping __init__ renames
# the key before the real constructor runs.
_original_init = InputLayer.__init__


def _patched_init(self, *args, **kwargs):
    # Forward to the original InputLayer.__init__ after translating
    # "batch_shape" -> "batch_input_shape".
    if "batch_shape" in kwargs:
        batch_shape = kwargs.pop("batch_shape")
        # If both keys were supplied, keep the explicit batch_input_shape but
        # still drop "batch_shape" so it is never forwarded to the original
        # constructor as an unexpected keyword argument.
        kwargs.setdefault("batch_input_shape", batch_shape)
    return _original_init(self, *args, **kwargs)


InputLayer.__init__ = _patched_init
# ---- END PATCH ----
"""1. Load mô hình + tokenizer"""
# Load the trained model and the fitted tokenizer once, at import time, so
# every request reuses them.
# Artifact paths are resolved relative to this file instead of the current
# working directory, so the service starts no matter where it is launched from.
try:
    _BASE_DIR = Path(__file__).resolve().parent
except NameError:  # no __file__, e.g. when the code is run in a notebook cell
    _BASE_DIR = Path.cwd()

# compile=False: the service only runs inference, so the training configuration
# (optimizer / loss) stored in the file is not restored. This also avoids load
# failures when that configuration cannot be deserialized by the installed
# Keras version.
model = tf.keras.models.load_model(
    str(_BASE_DIR / "bilstm_model.h5"), compile=False
)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only deploy a tokenizer.pkl produced by your own, trusted training run.
with open(_BASE_DIR / "tokenizer.pkl", "rb") as fh:
    tokenizer = pickle.load(fh)
"""2. Làm sạch dữ liệu"""
# Patterns are compiled once at import time rather than on every request.
_URL_RE = re.compile(r"http\S+")
# Anything that is not an ASCII letter/digit, a lowercase Vietnamese letter or
# a space is replaced by a space. The letters are listed in precomposed (NFC)
# form, which is why clean_text normalizes input to NFC first.
_DISALLOWED_CHARS_RE = re.compile(
    r"[^a-zA-Z0-9áàảãạăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệíìỉĩịóòỏõọôốồổỗộơớờởỡợúùủũụưứừửữựýỳỷỹỵđ ]"
)
_WHITESPACE_RE = re.compile(r"\s+")


def clean_text(text: str, *, tokenize: bool = True) -> str:
    """Normalize raw text into the form fed to the tokenizer.

    Steps: lowercase, Unicode NFC normalization, remove ``http...`` URLs,
    replace every character outside the allowed set with a space, collapse
    runs of whitespace and strip, then (by default) segment words with
    underthesea's ``word_tokenize(..., format="text")``.

    Args:
        text: Raw input string.
        tokenize: When False, skip word segmentation and return the
            normalized string. Defaults to True (the original behavior).

    Returns:
        The cleaned string.
    """
    # NFC first: decomposed (NFD) input, e.g. from some keyboards or macOS,
    # would otherwise have its combining diacritics stripped by the
    # allowed-character filter, splitting words apart.
    text = unicodedata.normalize("NFC", text.lower())
    text = _URL_RE.sub("", text)
    text = _DISALLOWED_CHARS_RE.sub(" ", text)
    text = _WHITESPACE_RE.sub(" ", text).strip()
    if not tokenize:
        return text
    return word_tokenize(text, format="text")
# Request body for POST /predict.
class PredictRequest(BaseModel):
    # Raw text to classify; predict() cleans and tokenizes it before scoring.
    text: str


# Response body for POST /predict.
class PredictResponse(BaseModel):
    # Predicted class: predict() sets 1 when prob > 0.5, otherwise 0.
    label: int
    # Score read from the model's first output for this input.
    prob: float
"""2.5. Hàm test api"""
@app.get("/health")
@app.post("/health")
def health():
    """Liveness probe: always returns ``{"status": "ok"}``.

    Served for GET, which uptime checkers and load balancers typically send,
    and for POST, which the endpoint originally accepted. Separate decorators
    give each method its own OpenAPI operation id.
    """
    return {"status": "ok"}
"""3. Hàm dự đoán"""
# Length every token-id sequence is padded/truncated to before scoring
# (presumably the length used at training time -- TODO confirm).
max_len = 300


@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    """Classify ``req.text`` with the BiLSTM model.

    The text is cleaned with ``clean_text``, mapped to token ids by the fitted
    tokenizer, padded/truncated to ``max_len`` ids and scored by the model.
    ``label`` is 1 when the score exceeds 0.5, otherwise 0.
    """
    cleaned = clean_text(req.text)
    seq = tokenizer.texts_to_sequences([cleaned])
    padded = pad_sequences(seq, maxlen=max_len)
    # predict() returns a 2-D array even for a single sample, so take [0][0].
    # float() turns the NumPy scalar into a built-in float that Pydantic and
    # JSON serialization accept. verbose=0 stops Keras printing a progress bar
    # on every request.
    prob = float(model.predict(padded, verbose=0)[0][0])
    label = 1 if prob > 0.5 else 0
    return PredictResponse(label=label, prob=prob)