File size: 1,805 Bytes
0bea9f7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import BertTokenizer, TFBertModel
import keras
import traceback
class Prediction(BaseModel):
    """Request body accepted by the ``POST /predict`` endpoint.

    pydantic validates the incoming JSON against this schema before the
    handler runs.
    """

    # Raw message text to be classified.
    text: str
# FastAPI application object; the route handlers below register on it.
app = FastAPI()

# Artifact locations. These are relative paths, so they resolve against the
# process's current working directory, not against this file's directory.
_TOKENIZER_DIR = './src/assets/'
_MODEL_PATH = './src/model/scam_class.h5'

# Both artifacts are loaded once, at import time, and shared by every request.
tokenizer = BertTokenizer.from_pretrained(_TOKENIZER_DIR)
# custom_objects maps the "TFBertModel" class name so Keras can deserialize
# it; presumably the saved model wraps a BERT encoder layer (TODO confirm).
model = keras.models.load_model(
    _MODEL_PATH,
    custom_objects={"TFBertModel": TFBertModel},
)
@app.get("/")
async def root():
    """Run the model on a fixed sample message and return the raw scores.

    This acts as an end-to-end smoke check: a plain GET exercises the
    module-level ``tokenizer`` and ``model`` without needing a request body.

    Returns:
        ``{"result": scores}``, where ``scores`` is the output of
        ``model.predict`` converted to nested Python lists. Only one input is
        sent, so there is one row. Presumably its columns are
        (neutral, scam, spam), matching the keys used by ``/predict``.
    """
    text = "Mainkan terus games mu sekarang dan beli koinnya pakai pulsamu, cek caranya di http://tsel.me/jajanonline"
    encoded = tokenizer(
        text=text,
        add_special_tokens=True,
        max_length=50,
        padding='max_length',
        truncation=True,
        return_tensors='tf',
        return_token_type_ids=False,
        verbose=True,
        return_attention_mask=True,
    )
    # The model is fed only ids and mask; token_type_ids are not requested above.
    input_obj = {'input_ids': encoded['input_ids'], 'attention_mask': encoded['attention_mask']}
    # NOTE(review): model.predict is synchronous, so it blocks the event loop
    # for the duration of inference inside this async handler.
    prediction = model.predict(input_obj)
    # tolist() turns the NumPy result into plain Python floats for JSON.
    # (The previous, unused per-class dict built from numpy scalars was dead
    # code and has been removed.)
    return {"result": prediction.tolist()}
@app.post("/predict")
async def predict(data: Prediction):
    """Classify ``data.text`` and return one score per class.

    Args:
        data: Validated request body; only ``data.text`` is used.

    Returns:
        A dict with keys ``"neutral"``, ``"scam"`` and ``"spam"``. They map to
        columns 0, 1 and 2 of the first row of ``model.predict``'s output.

    Raises:
        HTTPException: status 500 with detail ``"Something went wrong"`` if
            tokenization, inference or result unpacking raises an ordinary
            exception. The traceback is printed to stderr and kept as
            ``__cause__``.
    """
    try:
        encoded = tokenizer(
            text=data.text,
            add_special_tokens=True,
            max_length=50,
            padding='max_length',
            truncation=True,
            return_tensors='tf',
            return_token_type_ids=False,
            verbose=True,
            return_attention_mask=True,
        )
        # The model is fed only ids and mask; token_type_ids are not requested above.
        input_obj = {'input_ids': encoded['input_ids'], 'attention_mask': encoded['attention_mask']}
        # tolist() yields plain Python floats, which serialize cleanly to JSON.
        scores = model.predict(input_obj).tolist()
        return {"neutral": scores[0][0], "scam": scores[0][1], "spam": scores[0][2]}
    except Exception as exc:
        # This was a bare ``except:``. Narrowing it lets BaseException
        # subclasses (asyncio.CancelledError, KeyboardInterrupt, SystemExit)
        # propagate instead of being reported to the client as a 500.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Something went wrong") from exc
|