# mtyrrell's picture
# innovation classification
# 4fbb7a3
import re
from io import BytesIO
from openpyxl import Workbook
from openpyxl.styles import Font, NamedStyle, PatternFill
from openpyxl.styles.differential import DifferentialStyle
import logging
from logging.handlers import RotatingFileHandler
import os
import configparser
def setup_logging():
    """
    Configure root logging (rotating file + stream) and return this module's logger

    The root logger is configured only if it has no handlers yet: a rotating
    file handler writing to logs/app.log (1 MiB per file, 5 backups) and a
    stream handler are attached at INFO level. Later calls leave the existing
    configuration untouched and just return the logger.

    Returns
    ----------------
    logging.Logger named after this module (__name__)
    """
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

    # logging.basicConfig() does nothing once the root logger has handlers, so
    # the file handler is only built when it will actually be installed.
    # Building it unconditionally would open the log file again on every
    # repeat call and leave that handle orphaned.
    if not logging.getLogger().handlers:
        log_dir = 'logs'
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, 'app.log')

        # Rotate at 1 MiB, keeping the five most recent backups
        file_handler = RotatingFileHandler(log_file, maxBytes=1024 * 1024, backupCount=5)
        file_handler.setFormatter(logging.Formatter(log_format))

        logging.basicConfig(level=logging.INFO,
                            format=log_format,
                            handlers=[file_handler, logging.StreamHandler()])

    return logging.getLogger(__name__)
def getconfig(configfile_path: str):
    """
    Read the config file

    Params
    ----------------
    configfile_path: file path of .cfg file

    Returns
    ----------------
    configparser.ConfigParser holding the parsed file, or None when the file
    cannot be opened or parsed (a warning is logged in either case)
    """
    config = configparser.ConfigParser()
    try:
        # The context manager closes the file handle even if parsing fails
        with open(configfile_path) as config_file:
            config.read_file(config_file)
    except OSError:
        # Missing or unreadable file
        logging.warning("config file not found")
        return None
    except (configparser.Error, ValueError):
        # Malformed content (ValueError also covers undecodable bytes)
        logging.warning("config file could not be parsed")
        return None
    return config
def create_excel():
    """
    Build the upload template workbook and return it as raw bytes

    The workbook has a single sheet named "template" whose first row holds
    the expected column names, styled with a bold font on a solid fill.

    Returns
    ----------------
    bytes: content of the saved workbook
    """
    header_names = [
        'id',
        'organization',
        'scope',
        'technology',
        'financial',
        'barrier',
        'technology_rationale',
        'project_rationale',
        'project_objectives',
        'maf_funding_requested',
        'contributions_public_sector',
        'contributions_private_sector',
        'contributions_other',
        'mitigation_potential',
    ]

    workbook = Workbook()
    worksheet = workbook.active
    worksheet.title = "template"
    worksheet.append(header_names)  # header names become row 1

    # Style only the first row of the A1:N4 range, i.e. the 14 header cells
    header_fill = PatternFill('solid', fgColor='bad8e1')
    header_font = Font(bold=True)
    header_cells = worksheet['A1:N4'][0]
    for header_cell in header_cells:
        header_cell.fill = header_fill
        header_cell.font = header_font

    # Serialize in memory instead of writing to disk
    buffer = BytesIO()
    workbook.save(buffer)
    return buffer.getvalue()
def clean_text(input_text):
    """
    Strip unwanted characters from a string and collapse its whitespace

    Params
    ----------------
    input_text: string to clean

    Returns
    ----------------
    str: the cleaned text
    """
    # Applied strictly in this order; each step works on the previous result
    substitutions = (
        # keep only letters, digits, whitespace and basic punctuation
        (r"[^a-zA-Z0-9\s.,:;!?()\-\n]", ""),
        # drop literal "x000D" (presumably what is left of an "_x000D_"
        # escape once the underscores are stripped above - confirm)
        (r"x000D", ""),
        # collapse every whitespace run, newlines included, to one space
        (r"\s+", " "),
        # NOTE(review): cannot match anymore, the previous step already
        # replaced all newlines; kept to mirror the existing pipeline
        (r"\n+", "\n"),
    )

    cleaned_text = input_text
    for pattern, replacement in substitutions:
        cleaned_text = re.sub(pattern, replacement, cleaned_text)
    return cleaned_text
def extract_predicted_labels(output, ordinal_selection=1, threshold=0.5):
    """
    Pick the two best-scoring labels from a list of classifier results

    Params
    ----------------
    output: list of dicts, each read through its 'label' and 'score' keys
    ordinal_selection: accepted for interface compatibility; not used here
    threshold: only items with a score strictly above this value are kept

    Returns
    ----------------
    dict: {"SECTOR1": best label, "SECTOR2": second-best label}; a slot is
    None when no item qualifies for it or when output is malformed
    """
    # Guard clause: anything other than a list of dicts is rejected outright
    well_formed = isinstance(output, list) and all(isinstance(item, dict) for item in output)
    if not well_formed:
        print("Error: Inner data is not formatted correctly. Each item must be a dictionary.")
        return {"SECTOR1": None, "SECTOR2": None}

    # Keep items above the threshold, best score first (missing score counts as 0)
    ranked = sorted(
        (item for item in output if item.get('score', 0) > threshold),
        key=lambda entry: entry.get('score', 0),
        reverse=True,
    )

    if not ranked:
        print("Warning: Less than two items above the threshold in the current list.")

    # Take up to two labels and pad with None so both slots are always filled
    top_labels = [entry.get('label') for entry in ranked[:2]]
    top_labels += [None] * (2 - len(top_labels))

    return {"SECTOR1": top_labels[0], "SECTOR2": top_labels[1]}