| import json
|
| from bs4 import BeautifulSoup
|
| import re
|
| from tqdm import tqdm
|
| import sys
|
| import question_categorizer as qc
|
| import numpy as np
|
| from question_categorizer import TextClassificationModel
|
|
|
# Pre-trained question-category classifier, loaded once at import time.
# NOTE(review): this performs model I/O as a module-level side effect, so
# merely importing this file requires models/categorizer to exist.
qc_model = qc.TextClassificationModel.load_model("models/categorizer")

# Category names indexed by classifier output position — predictions are
# mapped back to these labels by index.  Assumes this ordering matches the
# model's training labels; TODO confirm against question_categorizer.
categories = ['Geography', 'Religion', 'Philosophy', 'Trash','Mythology', 'Literature','Science', 'Social Science', 'History', 'Current Events', 'Fine Arts']
|
|
|
def remove_newline(string):
    """Collapse every run of one or more newlines in *string* into a single space."""
    collapsed = re.sub(r"\n+", " ", string)
    return collapsed
|
|
|
def clean_text(text, answer):
    """Normalize *text* for training and scrub the answer string out of it.

    Pipeline: strip markup-style tags, convert '?' to '.', drop every
    character except letters, periods, whitespace and hyphens, remove
    occurrences of *answer* (underscores treated as spaces, parenthesized
    qualifiers dropped), then collapse whitespace.

    Returns the cleaned, stripped string.  Answer removal is best-effort:
    on failure it logs and returns whatever cleaning succeeded.
    """
    # Remove anything that looks like a markup tag.
    text = re.sub(r'<.*?>', '', text)

    # Questions end in '?'; normalize to sentence-final '.'.
    text = text.replace('?', '.')

    # Keep only letters, periods, whitespace, and hyphens.
    text = re.sub(r'[^a-zA-Z.\s-]', '', text)

    try:
        # e.g. a wiki title "Foo_Bar (disambiguation)" -> "Foo Bar"
        processed_answer = answer.replace('_', ' ')
        # Fix: stripping the parenthesized part used to leave trailing
        # whitespace ("Paris (city)" -> "Paris "), which never matched the
        # answer as it appears in the text.  Strip it off.
        processed_answer = re.sub(r'\([^)]*\)', '', processed_answer).strip()

        # Skip the substitution entirely when nothing is left to remove
        # (an empty pattern would be a pointless no-op sub).
        if processed_answer:
            text = re.sub(re.escape(processed_answer), '', text,
                          flags=re.IGNORECASE)
    except Exception as e:
        # Deliberate best-effort: log and continue with the partly cleaned text.
        print("An error occurred during text cleaning:", e)
        print("Text:", text)
        print("Answer:", answer)

    # Collapse all runs of whitespace to single spaces.
    text = re.sub(r'\s+', ' ', text)

    return text.strip()
|
|
|
def _predict_categories(text):
    """Classify *text* and return the matching category names plus 'ALL'.

    Propagates whatever qc_model.predict (or the indexing below) raises;
    callers decide whether a failure skips the entry.
    """
    prediction = qc_model.predict(text)
    # NOTE(review): argwhere returns (row, col) pairs, so [1] selects the
    # *second match*, not a column of indices — and it raises IndexError
    # when there are fewer than two matches.  Possibly [:, 1] was intended;
    # preserved as-is to keep behavior.  TODO confirm predict()'s shape.
    predictions = np.argwhere(prediction >= 1.5)[1]

    labels = [categories[prediction_ind] for prediction_ind in predictions]
    labels.append('ALL')
    return labels


def process_data():
    """Build data/training_data.json from the configured source files.

    Merges three sources — Jeopardy questions, Wikipedia pages, and misc
    QA pairs — into one list of {"text", "answer", "category"} entries and
    writes it as indented JSON.
    """
    # NOTE(review): jeopardy_data and wiki_files are empty, so the Jeopardy
    # and Wikipedia passes currently process nothing — presumably file
    # names are meant to be filled in here.  TODO confirm.
    jeopardy_data = []

    wiki_files = [
    ]

    question_files = [
        "qadata.json"]

    wiki_data = []
    question_data = []

    for file_path in wiki_files:
        with open('data/' + file_path, "r") as f:
            wiki_data.extend(json.load(f))

    for file_path in question_files:
        with open('data/' + file_path, "r") as f:
            question_data.extend(json.load(f))

    training_data = []

    print("Processing Jeopardy data...")
    for entry in tqdm(jeopardy_data):
        question = entry["question"]
        answer = str(entry["answer"])

        # Keep only the top-level text nodes of the (HTML-ish) question.
        soup = BeautifulSoup(question, 'html.parser')
        clean_question = ''.join(soup.findAll(text=True, recursive=False))

        # Categorization intentionally runs on the raw question, matching
        # the original behavior (not on clean_question).
        training_data.append({
            "text": clean_question,
            "answer": answer,
            "category": _predict_categories(question),
        })

    print("Processing Wikipedia data...")
    for entry in tqdm(wiki_data):
        page = str(entry["page"])
        text = entry["text"]

        if text == "":
            continue

        text = remove_newline(text)
        text = clean_text(text, page)

        training_data.append({
            "text": text,
            "answer": page,
            "category": _predict_categories(text),
        })

    print("Processing Misc data...")
    for entry in tqdm(question_data):
        answer = str(entry["answer"])
        text = entry["text"]

        if text == "" or answer == "":
            continue

        text = remove_newline(text)
        text = clean_text(text, answer)

        try:
            question_category = _predict_categories(text)
        except Exception as e:
            # Was a bare except; narrowed and the reason logged so skipped
            # entries can be diagnosed instead of vanishing silently.
            print("Categorization failed:", e)
            print("answer: " + str(answer))
            print("text:" + str(text))
            continue

        training_data.append({
            "text": text,
            "answer": answer,
            "category": question_category,
        })

    # Open the output only once everything succeeded, instead of holding a
    # truncated "w" handle open for the entire (long) processing run.
    with open("data/training_data.json", "w") as f:
        json.dump(training_data, f, indent=4)
|
|
|
| process_data()
|
|
|