Ginidu2003 commited on
Commit
fec7c1e
Β·
verified Β·
1 Parent(s): 0e9e335

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -63
app.py CHANGED
@@ -9,19 +9,20 @@ import re
9
  import string
10
 
11
  # ====================== NLTK SETUP ======================
 
12
  nltk.download('wordnet', quiet=True)
13
  nltk.download('punkt', quiet=True)
14
  nltk.download('punkt_tab', quiet=True)
15
 
 
16
  lemmatizer = WordNetLemmatizer()
17
 
18
  def preprocess_text(text):
19
  if not isinstance(text, str):
20
  return ""
21
  text = text.lower()
22
- # Remove punctuation except: ' " $ % ?
23
  punct_to_remove = string.punctuation.replace("'","").replace('"',"").replace("$","").replace("%","").replace("?","")
24
- text = re.sub(f"[{punct_to_remove}]", " ", text) # remove the specified ones
25
  tokens = nltk.word_tokenize(text)
26
  tokens = [word for word in tokens]
27
  tokens = [lemmatizer.lemmatize(word) for word in tokens]
@@ -29,24 +30,23 @@ def preprocess_text(text):
29
 
30
  # ====================== MODELS ======================
31
  classifier_model = "Ginidu2003/Distilbert-Base-News-classifier"
32
- qa_model = "deepset/roberta-base-squad2" # Best for news Q&A
33
 
34
- # ====================== CLASSIFICATION FUNCTION ======================
35
- import matplotlib.pyplot as plt
36
- import io
37
 
 
38
  @torch.no_grad()
39
- def classify_csv_with_chart(file):
40
  try:
41
  df = pd.read_csv(file)
42
-
43
  if 'content' not in df.columns:
44
- return "Error: CSV must have a column named 'content'", None, None
45
-
46
  df['clean_content'] = df['content'].apply(preprocess_text)
47
-
48
  classifier = pipeline("text-classification", model=classifier_model, device=-1)
49
-
50
  predictions = []
51
  for text in df['clean_content']:
52
  if not text.strip():
@@ -54,59 +54,28 @@ def classify_csv_with_chart(file):
54
  else:
55
  result = classifier(text)[0]
56
  predictions.append(result['label'])
57
-
58
  df['class'] = predictions
59
  df = df.drop(columns=['clean_content'], errors='ignore')
60
-
61
- # Save output
62
  output_file = "output.csv"
63
  df.to_csv(output_file, index=False)
64
-
65
- # ====== Category Distribution Chart ======
66
- fig, ax = plt.subplots()
67
- df['class'].value_counts().plot(kind='bar', ax=ax, color="skyblue")
68
- ax.set_title("Category Distribution")
69
- ax.set_ylabel("Count")
70
-
71
- buf = io.BytesIO()
72
- plt.savefig(buf, format="png")
73
- buf.seek(0)
74
-
75
- return f"βœ… Success! Classified {len(df)} rows", output_file, buf
76
-
77
  except Exception as e:
78
- return f"❌ Error: {str(e)}", None, None
79
 
80
  # ====================== Q&A FUNCTION ======================
81
- from transformers import AutoTokenizer, AutoModelForQuestionAnswering
82
-
83
- qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
84
- qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
85
-
86
  def answer_question(news_content, question):
87
  if not news_content.strip() or not question.strip():
88
  return "Please enter both news content and a question."
89
-
90
  try:
91
- inputs = qa_tokenizer(question, news_content, return_tensors="pt", truncation=True, max_length=512)
92
-
93
- with torch.no_grad():
94
- outputs = qa_model(**inputs)
95
-
96
- start_idx = torch.argmax(outputs.start_logits)
97
- end_idx = torch.argmax(outputs.end_logits) + 1
98
-
99
- # Clean answer - remove question repetition and special tokens
100
- answer = qa_tokenizer.decode(inputs.input_ids[0][start_idx:end_idx],
101
- skip_special_tokens=True,
102
- clean_up_tokenization_spaces=True)
103
-
104
- confidence = torch.max(torch.softmax(outputs.start_logits, dim=1)).item()
105
-
106
- return f"**Answer:** {answer.strip()}\n\n**Confidence:** {confidence:.2%}"
107
-
108
  except Exception as e:
109
- return f"Error: {str(e)}"
110
 
111
  # ====================== GRADIO INTERFACE ======================
112
  with gr.Blocks(title="Daily Mirror News Classifier") as demo:
@@ -114,27 +83,25 @@ with gr.Blocks(title="Daily Mirror News Classifier") as demo:
114
  gr.Markdown("### Section 02 - Text Analytics Assignment")
115
 
116
  with gr.Tabs():
117
- # Tab 1: Classification
118
  with gr.Tab("πŸ“Š Text Classification"):
 
119
  file_input = gr.File(label="Upload CSV", file_types=[".csv"])
120
  classify_btn = gr.Button("πŸš€ Classify News", variant="primary")
121
  output_text = gr.Textbox(label="Status")
122
  output_file = gr.File(label="Download output.csv")
123
- dist_plot = gr.Image(label="Category Distribution")
124
-
125
  classify_btn.click(
126
- fn=classify_csv_with_chart,
127
  inputs=file_input,
128
- outputs=[output_text, output_file, dist_plot]
129
  )
130
 
131
- # Tab 2: Q&A Pipeline
132
  with gr.Tab("❓ Question Answering"):
133
  gr.Markdown("Ask any question about a news article")
134
- news_input = gr.Textbox(lines=10, label="Paste News Content", placeholder="Paste the full news article here...")
135
- question_input = gr.Textbox(label="Your Question", placeholder="e.g., What is the main issue discussed?")
136
  qa_btn = gr.Button("πŸ” Get Answer", variant="primary")
137
- qa_output = gr.Textbox(label="Answer", lines=4)
138
 
139
  qa_btn.click(
140
  fn=answer_question,
@@ -144,4 +111,4 @@ with gr.Blocks(title="Daily Mirror News Classifier") as demo:
144
 
145
  gr.Markdown("Built for Text Analytics Assignment - Section 02")
146
 
147
- demo.launch()
 
9
  import string
10
 
11
  # ====================== NLTK SETUP ======================
12
+
13
  nltk.download('wordnet', quiet=True)
14
  nltk.download('punkt', quiet=True)
15
  nltk.download('punkt_tab', quiet=True)
16
 
17
+
18
  lemmatizer = WordNetLemmatizer()
19
 
20
  def preprocess_text(text):
21
  if not isinstance(text, str):
22
  return ""
23
  text = text.lower()
 
24
  punct_to_remove = string.punctuation.replace("'","").replace('"',"").replace("$","").replace("%","").replace("?","")
25
+ text = re.sub(f"[{punct_to_remove}]", " ", text)# Remove punctuation except: ' " $ % ?
26
  tokens = nltk.word_tokenize(text)
27
  tokens = [word for word in tokens]
28
  tokens = [lemmatizer.lemmatize(word) for word in tokens]
 
30
 
31
  # ====================== MODELS ======================
32
  classifier_model = "Ginidu2003/Distilbert-Base-News-classifier"
33
+ qa_model_name = "deepset/roberta-base-squad2"
34
 
35
+ # Load QA pipeline using a supported method
36
+ qa_pipeline = pipeline("document-question-answering", model=qa_model_name, device=-1)
 
37
 
38
+ # ====================== CLASSIFICATION FUNCTION ======================
39
  @torch.no_grad()
40
+ def classify_csv(file):
41
  try:
42
  df = pd.read_csv(file)
 
43
  if 'content' not in df.columns:
44
+ return "Error: CSV must have a column named 'content'", None
45
+
46
  df['clean_content'] = df['content'].apply(preprocess_text)
47
+
48
  classifier = pipeline("text-classification", model=classifier_model, device=-1)
49
+
50
  predictions = []
51
  for text in df['clean_content']:
52
  if not text.strip():
 
54
  else:
55
  result = classifier(text)[0]
56
  predictions.append(result['label'])
57
+
58
  df['class'] = predictions
59
  df = df.drop(columns=['clean_content'], errors='ignore')
60
+
 
61
  output_file = "output.csv"
62
  df.to_csv(output_file, index=False)
63
+
64
+ return f"βœ… Success! Classified {len(df)} rows", output_file
 
 
 
 
 
 
 
 
 
 
 
65
  except Exception as e:
66
+ return f"❌ Error: {str(e)}", None
67
 
68
  # ====================== Q&A FUNCTION ======================
 
 
 
 
 
69
  def answer_question(news_content, question):
70
  if not news_content.strip() or not question.strip():
71
  return "Please enter both news content and a question."
 
72
  try:
73
+ result = qa_pipeline(question=question, context=news_content)
74
+ answer = result[0]['answer'] if isinstance(result, list) else result['answer']
75
+ score = result[0]['score'] if isinstance(result, list) else result['score']
76
+ return f"**Answer:** {answer}\n\n**Confidence:** {score:.2%}"
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  except Exception as e:
78
+ return f"Error processing question: {str(e)}"
79
 
80
  # ====================== GRADIO INTERFACE ======================
81
  with gr.Blocks(title="Daily Mirror News Classifier") as demo:
 
83
  gr.Markdown("### Section 02 - Text Analytics Assignment")
84
 
85
  with gr.Tabs():
 
86
  with gr.Tab("πŸ“Š Text Classification"):
87
+ gr.Markdown("Upload CSV with `content` column")
88
  file_input = gr.File(label="Upload CSV", file_types=[".csv"])
89
  classify_btn = gr.Button("πŸš€ Classify News", variant="primary")
90
  output_text = gr.Textbox(label="Status")
91
  output_file = gr.File(label="Download output.csv")
92
+
 
93
  classify_btn.click(
94
+ fn=classify_csv,
95
  inputs=file_input,
96
+ outputs=[output_text, output_file]
97
  )
98
 
 
99
  with gr.Tab("❓ Question Answering"):
100
  gr.Markdown("Ask any question about a news article")
101
+ news_input = gr.Textbox(lines=12, label="Paste News Content", placeholder="Paste the full news article here...")
102
+ question_input = gr.Textbox(label="Your Question", placeholder="e.g. What is the main topic?")
103
  qa_btn = gr.Button("πŸ” Get Answer", variant="primary")
104
+ qa_output = gr.Textbox(label="Answer", lines=5)
105
 
106
  qa_btn.click(
107
  fn=answer_question,
 
111
 
112
  gr.Markdown("Built for Text Analytics Assignment - Section 02")
113
 
114
+ demo.launch()