| | import re |
| | from copy import deepcopy |
| | from fractions import Fraction |
| |
|
| | from deepmerge import Merger |
| | from dotmap import DotMap |
| |
|
| | from src.constants import FIELD_LABEL_NUMBER_REGEX |
| | from src.defaults import CONFIG_DEFAULTS, TEMPLATE_DEFAULTS |
| | from src.schemas.constants import FIELD_STRING_REGEX_GROUPS |
| | from src.utils.file import load_json |
| | from src.utils.validations import ( |
| | validate_config_json, |
| | validate_evaluation_json, |
| | validate_template_json, |
| | ) |
| |
|
# Shared deep-merge policy used when layering user JSON over the defaults.
OVERRIDE_MERGER = Merger(
    # Per-type strategies: dictionaries are merged key by key.
    type_strategies=[(dict, ["merge"])],
    # Every other type: the incoming value replaces the existing one.
    fallback_strategies=["override"],
    # Mismatched types on both sides: the incoming value also wins.
    type_conflict_strategies=["override"],
)
| |
|
| |
|
def get_concatenated_response(omr_response, template):
    """Build the final response dict from the raw per-field response.

    For each entry of ``template.custom_labels`` (label -> list of field
    keys) the values of those fields are joined, in order, into a single
    string. Every label in ``template.non_custom_labels`` is then copied
    over from ``omr_response`` unchanged.

    Raises:
        KeyError: if a referenced field is absent from ``omr_response``.
    """
    # Custom labels first: glue together the values of their constituent fields.
    combined = {
        label: "".join(omr_response[part] for part in parts)
        for label, parts in template.custom_labels.items()
    }

    # Then the plain labels, copied through as-is.
    for label in template.non_custom_labels:
        combined[label] = omr_response[label]

    return combined
| |
|
| |
|
def open_config_with_defaults(config_path):
    """Load the JSON config at ``config_path`` layered over ``CONFIG_DEFAULTS``.

    The merged result is validated with ``validate_config_json`` and returned
    wrapped in a ``DotMap``.
    """
    overrides = load_json(config_path)
    # Merge into a deep copy, presumably so the shared defaults stay untouched.
    merged_config = OVERRIDE_MERGER.merge(deepcopy(CONFIG_DEFAULTS), overrides)
    validate_config_json(merged_config, config_path)
    # NOTE(review): _dynamic=False should stop DotMap from auto-creating
    # missing keys on attribute access - confirm against the DotMap docs.
    return DotMap(merged_config, _dynamic=False)
| |
|
| |
|
def open_template_with_defaults(template_path):
    """Load the JSON template at ``template_path`` layered over ``TEMPLATE_DEFAULTS``.

    The merged template is validated with ``validate_template_json`` and
    returned as the merged object itself (no ``DotMap`` wrapping here).
    """
    overrides = load_json(template_path)
    # Merge into a deep copy, presumably so the shared defaults stay untouched.
    merged_template = OVERRIDE_MERGER.merge(deepcopy(TEMPLATE_DEFAULTS), overrides)
    validate_template_json(merged_template, template_path)
    return merged_template
| |
|
| |
|
def open_evaluation_with_validation(evaluation_path):
    """Load the evaluation JSON at ``evaluation_path`` and validate it.

    No defaults are merged in; the loaded object is returned as-is after
    ``validate_evaluation_json`` has been called on it.
    """
    evaluation_json = load_json(evaluation_path)
    validate_evaluation_json(evaluation_json, evaluation_path)
    return evaluation_json
| |
|
| |
|
def parse_fields(key, fields):
    """Expand a list of field strings into one flat list of field labels.

    Each entry of ``fields`` is expanded with ``parse_field_string`` and the
    results are concatenated in order.

    Args:
        key: Name of the owning entry; used only in the error message.
        fields: Iterable of field strings to expand.

    Raises:
        Exception: if a field string expands to a label already produced by
            an earlier field string.
    """
    expanded_labels = []
    seen_labels = set()
    for field_string in fields:
        expansion = parse_field_string(field_string)
        # Any label shared with an earlier expansion is an overlap.
        if seen_labels.intersection(expansion):
            raise Exception(
                f"Given field string '{field_string}' has overlapping field(s) with other fields in '{key}': {fields}"
            )
        seen_labels.update(expansion)
        expanded_labels += expansion
    return expanded_labels
| |
|
| |
|
def parse_field_string(field_string):
    """Expand a single field string into a list of field labels.

    A string without a ``.`` is returned as a one-element list. Otherwise the
    first match of ``FIELD_STRING_REGEX_GROUPS`` is unpacked as
    ``(prefix, start, end)`` and expanded to ``prefix{start}`` ...
    ``prefix{end}``, both ends inclusive.

    Raises:
        Exception: if ``start`` is not strictly less than ``end``.
        IndexError: if the string contains ``.`` but the regex finds no match.
    """
    # Plain label: nothing to expand.
    if "." not in field_string:
        return [field_string]

    matches = re.findall(FIELD_STRING_REGEX_GROUPS, field_string)
    field_prefix, start_text, end_text = matches[0]
    start = int(start_text)
    end = int(end_text)
    if start >= end:
        raise Exception(
            f"Invalid range in fields string: '{field_string}', start: {start} is not less than end: {end}"
        )

    # The upper bound is inclusive, hence end + 1.
    return [f"{field_prefix}{field_number}" for field_number in range(start, end + 1)]
| |
|
| |
|
def custom_sort_output_columns(field_label):
    """Return a sort key ``[prefix, number]`` for a field label.

    The first match of ``FIELD_LABEL_NUMBER_REGEX`` is unpacked as
    ``(prefix, suffix)``; the suffix is converted to ``int`` and an empty
    suffix counts as ``0``. Presumably this orders labels such as ``q2``
    before ``q10`` - verify against the regex definition.

    Raises:
        IndexError: if the regex finds no match in ``field_label``.
    """
    matches = re.findall(FIELD_LABEL_NUMBER_REGEX, field_label)
    label_prefix, label_suffix = matches[0]
    numeric_part = int(label_suffix) if label_suffix else 0
    return [label_prefix, numeric_part]
| |
|
| |
|
def parse_float_or_fraction(result):
    """Convert a number or numeric string to ``float``.

    Strings containing ``/`` are parsed as fractions with
    ``fractions.Fraction`` (e.g. ``"1/2"`` -> ``0.5``); any other value is
    passed straight to ``float``.

    Args:
        result: A number, or a string holding a number or a fraction.

    Returns:
        The value as a ``float``.

    Raises:
        ValueError: if the text is not a valid number or fraction.
        ZeroDivisionError: if a fraction string has a zero denominator.
        TypeError: if the value cannot be converted to ``float``.
    """
    # isinstance rather than an exact type comparison, so that str subclasses
    # holding a fraction also take the Fraction path instead of failing in float().
    if isinstance(result, str) and "/" in result:
        return float(Fraction(result))
    return float(result)
| |
|