"""Gradio demo for a fine-tuned BERT-style text classifier.

Loads a sequence-classification model from ``model_dir`` and exposes a
single-textbox web UI that returns the predicted label for the entered text.
"""

import string

import gradio as gr
import requests
import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Directory holding the fine-tuned model files.
model_dir = "my-bert-model"

# Load the model configuration, tokenizer and model weights.
config = AutoConfig.from_pretrained(
    model_dir,
    num_labels=2,
    finetuning_task="text-classification",
)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir, config=config)


def inference(input_text: str) -> str:
    """Classify ``input_text`` and return the predicted label.

    Args:
        input_text: Raw text to classify.

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring class.
    """
    # Tokenize as a batch of one. The deprecated ``pad_to_max_length`` flag is
    # dropped because ``padding="max_length"`` already requests that padding.
    inputs = tokenizer(
        [input_text],
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Pick the best class along the last (class) axis of the single-row batch
    # and map its id to a label.
    predicted_class_id = logits.argmax(dim=-1).item()
    return model.config.id2label[predicted_class_id]


# Gradio interface: one text input, one text output.
demo = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(label="Input Text", scale=2, container=False),
    outputs=gr.Textbox(label="Output Label"),
    # Sample inputs shown in the UI.
    # NOTE(review): each row carries a second value (presumably the expected
    # label) although there is only one input component -- confirm that the
    # installed Gradio version ignores the extra column.
    examples=[
        [
            "My last two weather pics from the storm on August 2nd. People packed up real fast after the temp dropped and winds picked up.",
            1,
        ],
        [
            "Lying Clinton sinking! Donald Trump singing: Let's Make America Great Again!",
            0,
        ],
    ],
    title="Tutorial: BERT-based Text Classification",
)

# Start the app only when executed as a script, so importing this module
# (e.g. to reuse ``inference`` or ``demo``) does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True)