Spaces:
Runtime error
Runtime error
| # import py_vncorenlp | |
| # import gradio as gr | |
| # import os | |
| # import shutil | |
| # from sentence_transformers import CrossEncoder | |
| # save_dir = './vncorenlp' | |
| # models_dir = os.path.join(save_dir, 'models') | |
| # #if os.path.exists(models_dir): | |
| # #j shutil.rmtree(models_dir) | |
| # # print("[DEBUG]: Delete model") | |
| # #print("[DEBUG]: Tao lai folder model") | |
| # #os.makedirs(save_dir + "/models", exist_ok=True) | |
| # print("[DEBUG]: Download model") | |
| # py_vncorenlp.download_model(save_dir=save_dir+'/') | |
| # print("[DEBUG]: Downdload model complete!") | |
| # #py_vncorenlp.download_model(save_dir='/absolute/path/to/vncorenlp') | |
| # print("[DEBUG] rdsegmenter setep") | |
| # rdrsegmenter = py_vncorenlp.VnCoreNLP(annotators=["wseg"], save_dir=save_dir) | |
| # def rerank(query,sentences): | |
| # print("[DEBUG]: Start rerank function...") | |
| # tokenized_query = rdrsegmenter.word_segment(query) | |
| # tokenized_sentences = [rdrsegmenter.word_segment(sent) for sent in sentences] | |
| # tokenized_pairs = [[tokenized_query, sent] for sent in tokenized_sentences] | |
| # MODEL_ID = 'itdainb/PhoRanker' | |
| # MAX_LENGTH = 512 | |
| # model = CrossEncoder(MODEL_ID, max_length=MAX_LENGTH) | |
| # # For fp16 usage | |
| # model.model.half() | |
| # scores = model.predict(tokenized_pairs) | |
| # # 0.982, 0.2444, 0.9253 | |
| # #print(scores) | |
| # return scores | |
| # # Create Gradio interface | |
| # interface = gr.Interface( | |
| # fn=rerank, | |
| # inputs=[ | |
| # gr.Textbox(label="Query", placeholder="Enter your query"), | |
| # gr.Textbox(label="Documents (one per line)", lines=5, placeholder="Enter documents to rank"), | |
| # ], | |
| # outputs=gr.Textbox(label="Reranked Documents"), | |
| # title="MonoT5 Reranking", | |
| # description="Provide a query and a list of documents to rerank them using MonoT5." | |
| # ) | |
| # # Launch the app | |
| # if __name__ == "__main__": | |
| # interface.launch() | |
| from sentence_transformers import CrossEncoder | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import py_vncorenlp | |
| py_vncorenlp.download_model(save_dir='/absolute/path/to/vncorenlp') | |
| rdrsegmenter = py_vncorenlp.VnCoreNLP(annotators=["wseg"], save_dir='/absolute/path/to/vncorenlp') | |
| query = "Trường UIT là gì?" | |
| sentences = [ | |
| "Trường Đại học Công nghệ Thông tin có tên tiếng Anh là University of Information Technology (viết tắt là UIT) là thành viên của Đại học Quốc Gia TP.HCM.", | |
| "Trường Đại học Kinh tế – Luật (tiếng Anh: University of Economics and Law – UEL) là trường đại học đào tạo và nghiên cứu khối ngành kinh tế, kinh doanh và luật hàng đầu Việt Nam.", | |
| "Quĩ uỷ thác đầu tư (tiếng Anh: Unit Investment Trusts; viết tắt: UIT) là một công ty đầu tư mua hoặc nắm giữ một danh mục đầu tư cố định" | |
| ] | |
| tokenized_query = rdrsegmenter.word_segment(query) | |
| tokenized_sentences = [rdrsegmenter.word_segment(sent) for sent in sentences] | |
| tokenized_pairs = [[tokenized_query, sent] for sent in tokenized_sentences] | |
| MODEL_ID = 'itdainb/PhoRanker' | |
| MAX_LENGTH = 256 | |
| model = CrossEncoder(MODEL_ID, max_length=MAX_LENGTH) | |
| # For fp16 usage | |
| model.model.half() | |
| scores = model.predict(tokenized_pairs) | |
| # 0.982, 0.2444, 0.9253 | |
| print(scores) | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID) | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | |
| # For fp16 usage | |
| model.half() | |
| features = tokenizer(tokenized_pairs, padding=True, truncation="longest_first", return_tensors="pt", max_length=MAX_LENGTH) | |
| model.eval() | |
| with torch.no_grad(): | |
| model_predictions = model(**features, return_dict=True) | |
| logits = model_predictions.logits | |
| logits = torch.nn.Sigmoid()(logits) | |
| scores = [logit[0] for logit in logits] | |
| # 0.9819, 0.2444, 0.9253 | |
| print(scores) | |