Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import re | |
# Hugging Face Hub repositories served by this app. Each key is the label
# shown in the UI; each value is the repository id passed to `from_pretrained`
# (here the two are the same string).
models = {
    "RUSpam/spam_deberta_v4": "RUSpam/spam_deberta_v4",
    "RUSpam/spamNS_v1": "RUSpam/spamNS_v1",
}

# Load every tokenizer/model pair once at import time so that each request
# reuses the already-loaded objects.
tokenizers = {}
model_instances = {}
for model_name, repo_id in models.items():
    tokenizers[model_name] = AutoTokenizer.from_pretrained(repo_id)
    model_instances[model_name] = AutoModelForSequenceClassification.from_pretrained(repo_id)

# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: only spamNS_v1 is moved to `device` and explicitly put in eval mode
# here; the DeBERTa model is left exactly as `from_pretrained` returned it.
spamns_on_device = model_instances["RUSpam/spamNS_v1"].to(device)
model_instances["RUSpam/spamNS_v1"] = spamns_on_device.eval()
def clean_text(text):
    """Normalize raw text before it is tokenized.

    Steps, in order:
      1. remove every ``http...`` run up to the next whitespace;
      2. replace each run of characters outside ``А-Я``, ``а-я``, ``0-9``
         and the space character with a single space;
      3. lowercase the result and trim leading/trailing whitespace.

    Consecutive spaces are not collapsed, and ``Ё``/``ё`` lie outside the
    kept ranges, so they are replaced like any other excluded character.

    Args:
        text: Input string.

    Returns:
        The cleaned, lowercased string.
    """
    without_links = re.sub(r'http\S+', '', text)
    allowed_only = re.sub(r'[^А-Яа-я0-9 ]+', ' ', without_links)
    return allowed_only.lower().strip()
def predict_spam_deberta(text):
    """Classify ``text`` with the ``RUSpam/spam_deberta_v4`` model.

    The text is tokenized as-is (no cleaning) and truncated to 256 tokens.

    Args:
        text: Raw message text.

    Returns:
        ``"Спам"`` when the arg-max class index is 1, otherwise ``"Не спам"``.
    """
    tokenizer = tokenizers["RUSpam/spam_deberta_v4"]
    model = model_instances["RUSpam/spam_deberta_v4"]
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    # Move every input tensor to the device the model's weights are on.
    # The previous code moved two tensors to the global `device` and then
    # ignored them, calling the model with the un-moved inputs instead.
    # Following the model's own device is correct whether or not the model
    # has been relocated.
    inputs = {key: tensor.to(model.device) for key, tensor in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=1).item()
    return "Спам" if predicted_class == 1 else "Не спам"
def predict_spam_spamns(text):
    """Classify ``text`` with the ``RUSpam/spamNS_v1`` model.

    The text is first normalized with ``clean_text``, then tokenized with
    padding/truncation to exactly 128 tokens. A sigmoid is applied to the
    logits and the value at position ``[0][0]`` is thresholded at 0.5.

    Args:
        text: Raw message text.

    Returns:
        ``"Спам"`` when the sigmoid score is at least 0.5, otherwise
        ``"Не спам"``.
    """
    model_key = "RUSpam/spamNS_v1"
    tokenizer = tokenizers[model_key]
    model = model_instances[model_key]

    cleaned = clean_text(text)
    encoding = tokenizer(
        cleaned,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    ids = encoding['input_ids'].to(device)
    mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits

    spam_score = torch.sigmoid(logits).cpu().numpy()[0][0]
    if spam_score >= 0.5:
        return "Спам"
    return "Не спам"
def predict_spam(text, model_choice):
    """Route ``text`` to the classifier selected in the UI.

    Args:
        text: Raw message text to classify.
        model_choice: Model label; one of ``"RUSpam/spam_deberta_v4"`` or
            ``"RUSpam/spamNS_v1"``.

    Returns:
        The label string returned by the selected classifier.

    Raises:
        gr.Error: If ``model_choice`` is not one of the supported labels.
            Previously this case fell through and returned ``None``.
    """
    if model_choice == "RUSpam/spam_deberta_v4":
        return predict_spam_deberta(text)
    if model_choice == "RUSpam/spamNS_v1":
        return predict_spam_spamns(text)
    # User-facing message is in Russian to match the rest of the interface.
    raise gr.Error(f"Неизвестная модель: {model_choice!r}")
# Build the Gradio interface: a multi-line text box plus a model selector as
# inputs, and a label component for the classification result.
text_input = gr.Textbox(lines=5, label="Введите текст")
model_selector = gr.Radio(
    choices=list(models),
    label="Выберите модель",
    value="RUSpam/spam_deberta_v4",
)

iface = gr.Interface(
    fn=predict_spam,
    inputs=[text_input, model_selector],
    outputs=gr.Label(label="Результат"),
    title="Определение спама в русскоязычных текстах",
)

# Start the web app.
iface.launch()