| import gradio as gr |
| import torch |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| import torch.nn.functional as F |
|
|
# Checkpoint identifier shared by the tokenizer and the classifier; naming it
# once keeps the two from drifting apart when the checkpoint is changed.
MODEL_NAME = "tuantc/flysugarbot"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# Inference only: eval() switches off training-time behaviour such as dropout.
model.eval()
|
|
# Intent names. The position of each name is used as the class index when
# decoding the classifier's output, so the order matters.
# NOTE(review): the order must match the label order the checkpoint was
# trained with -- that cannot be verified from this file.
classes = [
    'admin_hr_avi_hr_lead',
    'airbus_avi_hr_lead',
    'avi_hr_lead',
    'bot_challenge',
    'bus_info',
    'c99_avi_hr_lead',
    'card',
    'channel_info',
    'congras_general',
    'contract',
    'doc_email_sw',
    'dxg_avi_hr_lead',
    'empl_password',
    'encourage',
    'excited',
    'fly_avi_hr_lead',
    'general_spam',
    'goodbye',
    'greeting',
    'laptop_pc',
    'lover_spam',
    'meeting_room',
    'on_leave',
    'palantir_avi_hr_lead',
    'printer',
    'property',
    'salary_insurance_avi_hr_lead',
    'services',
    'spam_eating',
    'sw_install',
    'thoi_viec',
    'tms',
    'union',
    'wifi',
    'working_hour',
]
NUM_LABELS = len(classes)

# Class index -> intent name. Built with enumerate() so no loop variable is
# left behind in the module namespace.
label_map = dict(enumerate(classes))
|
|
def predict_intent(model, tokenizer, text_instance, label_map):
    """Classify one piece of text and return its intent label and confidence.

    Args:
        model: Callable sequence classifier. It is invoked as
            ``model(input_ids, attention_mask=...)`` and element ``0`` of
            its output is taken as the logits.
        tokenizer: Callable that encodes ``text_instance`` into a mapping
            holding ``input_ids`` and ``attention_mask`` tensors.
        text_instance: Text to classify. It must encode to a single row;
            a batch of several texts makes the ``.item()`` call raise.
        label_map: Mapping from class index to intent name.

    Returns:
        A ``(predicted_label, prob)`` tuple: the intent name with the
        highest logit and the softmax probability of that class as a float.
    """
    # Send the inputs to whichever device holds the weights instead of
    # assuming CPU, so a model moved to an accelerator keeps working.
    # Objects without parameters fall back to the previous CPU behaviour.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cpu")

    encoding = tokenizer(
        text_instance,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512,
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask)[0]

    predicted_index = torch.argmax(logits, dim=-1).item()
    # Normalise over the same (last) axis argmax used and read the
    # probability of the chosen class, so label and confidence always
    # describe the same prediction.
    probabilities = F.softmax(logits, dim=-1)
    prob = probabilities[0, predicted_index].item()
    return label_map[predicted_index], prob
def bot_response(message, history):
    """Chat callback: describe the predicted intent of ``message``.

    ``history`` is accepted because the chat UI passes it to the callback,
    but it is not read here; every message is classified on its own.

    Returns:
        A Vietnamese sentence naming the predicted intent and its
        confidence rounded to two decimals.
    """
    predicted_intent, prob = predict_intent(
        model=model,
        tokenizer=tokenizer,
        text_instance=message,
        label_map=label_map,
    )
    reply = f"Intent này là {predicted_intent} với confidence {prob:.2f}"
    return reply
# Chat UI whose replies come from bot_response.
demo = gr.ChatInterface(fn=bot_response)
# share=True additionally requests a public share link from Gradio.
demo.launch(share=True)
|
|