euler03 committed on
Commit
2be9e02
·
verified ·
1 Parent(s): 3e43377

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Pick the GPU when one is available; models and input tensors are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Use of: {device}")

# Available models: dropdown display name -> Hugging Face Hub repo id.
MODELS = {
    "Aubins/distil-bumble-bert": "Aubins/distil-bumble-bert",
    # Add models here
}

# Labels mapping between classifier-head output indices and class names.
# Index 0 = BIASED, index 1 = NEUTRAL (both maps must stay in sync).
id2label = {0: "BIASED", 1: "NEUTRAL"}
label2id = {"BIASED": 0, "NEUTRAL": 1}

# Cache for loaded models: model name -> (model, tokenizer) tuple.
loaded_models = {}
21
def load_model(model_name: str):
    """
    Load a model and its tokenizer, caching the pair for reuse.

    Args:
        model_name (str): Key into the MODELS registry.

    Returns:
        tuple | str: ``(model, tokenizer)`` on success, or an error-message
        string on failure — callers detect errors with ``isinstance(..., str)``.
    """
    # Serve from the cache when this model was already loaded.
    if model_name in loaded_models:
        return loaded_models[model_name]

    # Fail fast with a clear message for an unknown key; previously this
    # raised KeyError inside the try and surfaced as the cryptic
    # "Error loading model: 'name'".
    if model_name not in MODELS:
        return f"Error loading model: unknown model '{model_name}'"

    try:
        model_path = MODELS[model_name]

        # Load the classification head with the app-wide label maps so the
        # output indices stay aligned with id2label/label2id.
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            num_labels=2,
            id2label=id2label,
            label2id=label2id,
        ).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Cache only after both halves loaded successfully.
        loaded_models[model_name] = (model, tokenizer)
        return model, tokenizer

    except Exception as e:
        # Keep the error-as-string contract expected by analyze_text().
        return f"Error loading model: {str(e)}"
51
+
52
def analyze_text(text: str, model_name: str):
    """
    Analyze a text for bias vs. neutrality with the selected model.

    Args:
        text (str): The text to analyze.
        model_name (str): The name of the model to use (a key of MODELS).

    Returns:
        dict, str: A dictionary of confidence scores per display label,
        and a human-readable summary message.
    """
    # Guard clause: nothing to classify.
    if not text.strip():
        return {"Empty text": 1.0}, "Please enter a text to be analyzed."

    # load_model() signals failure by returning an error-message string.
    result = load_model(model_name)
    if isinstance(result, str):
        return {"Error": 1.0}, result

    model, tokenizer = result

    try:
        # Tokenize; truncate to the 512-token limit of BERT-style models.
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        )

        # Move input tensors to the same device as the model.
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: disable dropout and gradient tracking.
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs)

        # Single-example batch -> take the first row of logits.
        logits = outputs.logits[0]
        probabilities = torch.nn.functional.softmax(logits, dim=0)
        predicted_class = torch.argmax(logits).item()
        # (dropped: `predicted_label = id2label[predicted_class]` was
        # computed but never used)

        # Display labels (title case) intentionally differ from the
        # training-time id2label strings; index 0 = BIASED, 1 = NEUTRAL.
        confidence_map = {
            "Neutral": probabilities[1].item(),
            "Biased": probabilities[0].item(),
        }

        status = "neutral" if predicted_class == 1 else "biased"
        confidence = probabilities[predicted_class].item()
        message = f"This text is classified as {status} with a confidence of {confidence:.2%}."

        return confidence_map, message

    except Exception as e:
        return {"Error": 1.0}, f"Analysis error: {str(e)}"
109
+
110
# Gradio interface: two-column layout — inputs on the left, results on the
# right — wired to analyze_text() via the button's click event.
with gr.Blocks(title="Objectivity detector in texts") as app:
    gr.Markdown("# Objectivity detector in texts")
    gr.Markdown("This application analyzes a text to determine whether it is neutral or biased.")

    with gr.Row():
        # Left column (wider): model choice, text entry, trigger button.
        with gr.Column(scale=3):
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                label="Select a model",
                # Default to the first registered model.
                value=list(MODELS.keys())[0]
            )

            text_input = gr.Textbox(
                placeholder="Enter the text to be analyzed...",
                label="Text to analyze",
                lines=10
            )

            analyze_button = gr.Button("Analyze the text")

        # Right column (narrower): per-label confidences and summary text.
        with gr.Column(scale=2):
            confidence_output = gr.Label(
                label="Analysis results",
                # Show both classes (Neutral / Biased) with their scores.
                num_top_classes=2,
                show_label=True
            )

            result_message = gr.Textbox(label="Detailed results")

    # Button click -> analyze_text(text, model) -> (label dict, message).
    analyze_button.click(
        analyze_text,
        inputs=[text_input, model_dropdown],
        outputs=[confidence_output, result_message]
    )

    gr.Markdown("## How to use this application")
    gr.Markdown("""
    1. Select an analysis model from the drop-down menu
    2. Enter or paste the text to be analyzed into the text box (in English only).
    3. Click on “Analyze the text”.
    4. The result is displayed with a visual indication
    """)

# Launch the web app only when run as a script (not on import).
if __name__ == "__main__":
    app.launch()