File size: 3,787 Bytes
48b307f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import requests
import streamlit as st
import torch
from langdetect import DetectorFactory
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from lxml.html.clean import Cleaner
from newspaper import Article
from newspaper.article import ArticleException
from PIL import Image
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import FSMTForConditionalGeneration, FSMTTokenizer

# Page header: title followed by the project logo (read from logo.jpg in the
# current working directory).
st.markdown("## Prediction of Fakeness by Given URL")
background = Image.open('logo.jpg')
st.image(background)

# URL input. It is pre-filled with an example article, so the rest of the
# script produces a prediction as soon as the page loads.
st.markdown("### Article URL")
text = st.text_area(
    "Insert some url here",
    value="https://en.globes.co.il/en/article-yandex-looks-to-expand-activities-in-israel-1001406519",
)

@st.cache_resource
def get_models_and_tokenizers():
    """Load the fake-news classifier and the Russian->English translator.

    Decorated with ``st.cache_resource`` so the models are loaded once per
    server process and the same objects are reused on every rerun.
    ``st.cache_data`` pickles the return value and hands each caller a fresh
    copy, which for torch models means re-serialising the weights on every
    interaction. ``cache_resource`` is Streamlit's intended cache for ML
    models and replaces the old ``st.cache(allow_output_mutation=True)``.

    Returns:
        A tuple ``(model, tokenizer, model_translator, tokenizer_translator)``:
        the two-label sequence classifier restored from the local fine-tuned
        checkpoint, its tokenizer, and the facebook/wmt19-ru-en FSMT
        translation model with its tokenizer. Both models are put in eval mode.
    """
    model_name = 'distilbert-base-uncased-finetuned-sst-2-english'
    checkpoint_dir = './my_saved_model/checkpoint-6320/'  # Path to your checkpoint folder

    # Classifier weights come from the fine-tuned checkpoint, but the tokenizer
    # comes from the base model. Presumably the checkpoint was fine-tuned from
    # it and ships no tokenizer files of its own - confirm.
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Russian -> English translator, applied to articles detected as 'ru'.
    model_name_translator = "facebook/wmt19-ru-en"
    tokenizer_translator = FSMTTokenizer.from_pretrained(model_name_translator)
    model_translator = FSMTForConditionalGeneration.from_pretrained(model_name_translator)

    # Inference only: switch both models out of training mode (e.g. no dropout).
    model.eval()
    model_translator.eval()
    return model, tokenizer, model_translator, tokenizer_translator

model, tokenizer, model_translator, tokenizer_translator = get_models_and_tokenizers()

# --- Fetch and extract the article -----------------------------------------
# If the URL is blank or the article cannot be downloaded/parsed, end this run
# with a readable message instead of a traceback.
url = text.strip()
if not url:
    st.warning('Please insert an article URL above.')
    st.stop()

article = Article(url)
try:
    article.download()
    article.parse()
except ArticleException as err:
    st.error(f'Could not download or parse the article: {err}')
    st.stop()
concated_text = article.title + '. ' + article.text

# --- Language detection ----------------------------------------------------
# langdetect is randomised. A fixed seed makes the result reproducible, so the
# same article does not switch language between Streamlit reruns.
DetectorFactory.seed = 0
try:
    lang = detect(concated_text)
except LangDetectException:
    # Raised when langdetect finds nothing it can classify in the text.
    st.error('Could not detect the language of this article.')
    st.stop()

st.markdown("### Language detection")

if lang == 'ru':
    st.markdown(f"The language of this article is {lang.upper()} so we translated it!")
    with st.spinner('Waiting for translation'):
        # The source text is truncated to 512 tokens before translation.
        input_ids = tokenizer_translator.encode(
            concated_text, return_tensors="pt", max_length=512, truncation=True)
        outputs = model_translator.generate(input_ids)
        decoded = tokenizer_translator.decode(outputs[0], skip_special_tokens=True)
        st.markdown("### Translated Text")
        st.markdown(decoded[:777])
        # Classify the English translation rather than the Russian original.
        concated_text = decoded
else:
    st.markdown(f"The language of this article for sure:  {lang.upper()}!")

    st.markdown("### Extracted Text")
    st.markdown(concated_text[:777])

# --- Fakeness prediction ---------------------------------------------------
tokens_info = tokenizer(concated_text, truncation=True, return_tensors="pt")
with torch.no_grad():
    raw_predictions = model(**tokens_info)
# Probability of class index 1, as an integer percentage in 0-100 (truncated).
# NOTE(review): assumes label 1 means "fake" - confirm against the training labels.
fake_percent = int(torch.nn.functional.softmax(raw_predictions.logits[0], dim=0)[1] * 100)
st.markdown("### Fakeness Prediction")
st.progress(fake_percent)
st.markdown(f"This is fake by *{fake_percent}%*!")
if fake_percent > 70:
    st.error('We would not trust this text!')
elif fake_percent > 40:
    st.warning('We are not sure about this text!')
else:
    st.success('We would trust this text!')