Ashendilantha committed on
Commit
0dd8629
·
verified ·
1 Parent(s): c1691a3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -0
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import streamlit as st
4
+ import pandas as pd
5
+ import numpy as np
6
+ import re
7
+ import nltk
8
+ from nltk.corpus import stopwords
9
+ from nltk.tokenize import word_tokenize
10
+ from nltk.stem import WordNetLemmatizer
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
13
+ import requests
14
+ from io import BytesIO
15
+
16
# Set page configuration (must be the first Streamlit call in the script)
st.set_page_config(page_title="News Analysis App", layout="wide")

# Download required NLTK resources once per process; st.cache_resource
# keeps this from re-running on every Streamlit rerun.
@st.cache_resource
def download_nltk_resources():
    """Fetch the NLTK corpora needed for tokenization and lemmatization."""
    # 'punkt_tab' is required by word_tokenize on NLTK >= 3.8.2; fetching it
    # alongside 'punkt' keeps both old and new NLTK versions working.
    # 'omw-1.4' backs WordNetLemmatizer lookups on recent NLTK releases.
    for resource in ('punkt', 'punkt_tab', 'stopwords', 'wordnet', 'omw-1.4'):
        nltk.download(resource, quiet=True)

download_nltk_resources()

# Preprocessing helpers shared by the classification pipeline.
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()
31
+
32
# Load the fine-tuned model for classification
@st.cache_resource
def load_classification_model():
    """Load the fine-tuned news classifier and its tokenizer (cached per process)."""
    repo_id = "Oneli/News_Classification"  # fine-tuned checkpoint on the HF Hub
    clf_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    clf_model = AutoModelForSequenceClassification.from_pretrained(repo_id)
    return clf_model, clf_tokenizer
39
+
40
# Load Q&A pipeline
@st.cache_resource
def load_qa_pipeline():
    """Build the extractive question-answering pipeline (cached per process)."""
    return pipeline(
        "question-answering",
        model="distilbert-base-cased-distilled-squad",
    )
45
+
46
# Text preprocessing function
def preprocess_text(text):
    """Normalize one raw article value for classification.

    Lowercases, strips URLs/HTML tags/non-letter characters, tokenizes,
    removes English stopwords, and lemmatizes each remaining token.

    Args:
        text: a cell from the 'content' column; may be NaN/None or non-string.

    Returns:
        The cleaned, space-joined token string; "" for missing values.
    """
    if pd.isna(text):
        return ""

    # CSV cells can hold numeric values; coerce to str so .lower() cannot fail.
    text = str(text).lower()

    # Remove URLs (http/https/www prefixes through the next whitespace)
    text = re.sub(r'http\S+|www\S+|https\S+', '', text)

    # Remove HTML tags
    text = re.sub(r'<.*?>', '', text)

    # Keep only letters and whitespace (digits and punctuation dropped)
    text = re.sub(r'[^a-zA-Z\s]', '', text)

    # Tokenize, drop stopwords, lemmatize, and re-join into a single string.
    tokens = word_tokenize(text)
    cleaned_tokens = [lemmatizer.lemmatize(token) for token in tokens if token not in stop_words]
    return ' '.join(cleaned_tokens)
73
+
74
# Function to classify news articles
def classify_news(df, model, tokenizer):
    """Predict a category label for every article in *df*.

    Args:
        df: DataFrame with a 'content' column of raw article text.
        model: sequence-classification model; its config maps class ids to labels.
        tokenizer: the tokenizer matching *model*.

    Returns:
        The same DataFrame with 'cleaned_content' and 'class' columns added
        (the input frame is mutated in place and also returned).
    """
    # Preprocess the text
    df['cleaned_content'] = df['content'].apply(preprocess_text)
    texts = df['cleaned_content'].tolist()

    # Disable dropout/batch-norm updates so inference is deterministic.
    model.eval()

    # Predict in small batches to bound memory use.
    predictions = []
    batch_size = 16
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, padding=True, truncation=True,
                           max_length=512, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**inputs)
            batch_predictions = torch.argmax(outputs.logits, dim=1).tolist()
            predictions.extend(batch_predictions)

    # Map numeric predictions back to class labels
    id2label = model.config.id2label
    df['class'] = [id2label[pred] for pred in predictions]

    return df
101
+
102
# Main app
def _classification_page():
    """Render the CSV-upload news classification page."""
    st.header("News Article Classification")
    st.write("Upload a CSV file containing news articles to classify them into categories.")

    # File upload
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
    if uploaded_file is None:
        return

    # Load the data and show a preview so the user can sanity-check it.
    df = pd.read_csv(uploaded_file)
    st.subheader("Sample of uploaded data")
    st.dataframe(df.head())

    # The classifier reads article text from this column.
    if 'content' not in df.columns:
        st.error("The CSV file must contain a 'content' column with the news articles text.")
        return

    with st.spinner("Loading classification model..."):
        model, tokenizer = load_classification_model()

    if st.button("Classify Articles"):
        with st.spinner("Classifying news articles..."):
            result_df = classify_news(df, model, tokenizer)

        # Display results
        st.subheader("Classification Results")
        st.dataframe(result_df[['content', 'class']])

        # Offer the labeled data back to the user as a CSV download.
        csv = result_df.to_csv(index=False)
        st.download_button(
            label="Download output.csv",
            data=csv,
            file_name="output.csv",
            mime="text/csv"
        )

        # Show distribution of classes
        st.subheader("Class Distribution")
        class_counts = result_df['class'].value_counts()
        st.bar_chart(class_counts)


def _qa_page():
    """Render the free-text question-answering page."""
    st.header("News Article Q&A")
    st.write("Ask questions about news content and get answers using a Q&A model.")

    # Text area for news content and the question about it
    news_content = st.text_area("Paste news article content here:", height=200)
    question = st.text_input("Enter your question about the article:")

    if news_content and question:
        with st.spinner("Loading Q&A model..."):
            qa_pipeline = load_qa_pipeline()

        if st.button("Get Answer"):
            with st.spinner("Finding answer..."):
                result = qa_pipeline(question=question, context=news_content)

            # Display the extracted answer and the model's confidence.
            st.subheader("Answer")
            st.write(result["answer"])

            st.subheader("Confidence")
            # score is a float in [0, 1], which is what st.progress expects.
            st.progress(float(result["score"]))
            st.write(f"Confidence Score: {result['score']:.4f}")


def main():
    """Entry point: route between the two app modes via the sidebar."""
    st.title("News Analysis Application")

    # Sidebar for navigation
    st.sidebar.title("Navigation")
    app_mode = st.sidebar.radio("Choose the app mode", ["News Classification", "Question Answering"])

    if app_mode == "News Classification":
        _classification_page()
    elif app_mode == "Question Answering":
        _qa_page()


if __name__ == "__main__":
    main()