| | import torch |
| | import nltk |
| | from nltk import pos_tag |
| | from nltk.tokenize import word_tokenize |
| | from nltk.corpus import wordnet |
| | from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel |
| | from torch import nn |
| | from itertools import chain |
| | from torch.nn import MSELoss, CrossEntropyLoss |
| | from cleantext import clean |
| | from num2words import num2words |
| | import re |
| | import string |
| | import inflect |
| |
|
# Download, at import time, the NLTK data packages this module asks for
# (tokenizer models, POS-tagger models and the WordNet corpus).
for _nltk_resource in (
    'punkt',
    'punkt_tab',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
    'wordnet',
):
    nltk.download(_nltk_resource)
| | |
| |
|
# Sorted list of every character treated as punctuation: ASCII punctuation
# plus typographic quotes, dashes and the ellipsis.
punct_chars = sorted(
    set(string.punctuation)
    | {'’', '‘', '–', '—', '~', '|', '“', '”', '…', "'", "`", '_'}
)
# The same characters as a single string.
punctuation = ''.join(punct_chars)
# Compiled character class matching any one punctuation character.
replace = re.compile('[%s]' % re.escape(punctuation))
| |
|
# Math-related word beginnings.
# NOTE(review): the code that consumes this list is not visible here;
# presumably entries are matched against the start of a word — confirm.
MATH_PREFIXES = [
    # Complete words.
    "sum",
    "arc",
    "mass",
    "digit",
    "graph",
    "liter",
    "gram",
    "add",
    "angle",
    "scale",
    "data",
    "array",
    "ruler",
    "meter",
    "total",
    "unit",
    "prism",
    "median",
    "ratio",
    "area",

    # Truncated stems.
    "multipl",
    "divid",
    "subtrac",
    "logarit",
    "algebr",
    "calcul",
    "matri",
    "vect",
    "geometr",
    "statist",
    # Spelled "probabil" so it is a prefix of "probability"/"probabilistic".
    "probabil",
    "coeffi",
    "measure",
    "simplif",
]
| |
|
# Vocabulary of math terms and phrases (singular forms; some entries are
# deliberately truncated stems such as "estimat" or "multipl").
# Every entry MUST end with a comma: Python silently concatenates adjacent
# string literals, which fuses two terms into one bogus entry.
MATH_WORDS = [
    "absolute deviation",
    "absolute value",
    "abundant number",
    "accurate",
    "acre",
    "acute",
    "add",
    "addend",
    "addition fact",
    "addition",
    "additive identity",
    "additive inverse",
    "adjacent",
    "algebra",
    "algebraic",
    "algorithm",
    "alternate interior angle",
    "altitude",
    "analog",
    "angle measure",
    "angle",
    "angular",
    "apex",
    "approximate",
    "arc",
    "area model",
    "area",
    "arithmetic fact",
    "arithmetic",
    "array",
    "associative property",
    "associative",
    "astronomical unit",
    "attribute",
    "average",
    "axis",
    "bar graph",
    "base of a parallelogram",
    "base of a prism",
    "base of a pyramid",
    "base of a triangle",
    "base of an exponent",
    "base of",
    "base ten",
    "base",
    "baseline",
    "benchmark fraction",
    "billion",
    "binomial",
    "bisect",
    "bisector",
    "box and whisker plot",
    "box plot",
    "capacity",
    "cartesian coordinate",
    "categorical data",
    "categorical",
    "celsius",
    "census",
    "cent",
    "center of a circle",
    "center of a dilation",
    "center of a sphere",
    "center",
    "centimeter",
    "central angle",
    "centroid",
    "chance experiment",
    "chance",
    "chord",
    "circle graph",
    "circle",
    "circular",
    "circumference",
    "clockwise",
    "coefficient",
    "collinear",
    "column matrix",
    "column",
    "combination",
    "combine",
    "common denominator",
    "common factor",
    "common fraction",
    "common multiple",
    "commutative property",
    "commutative",
    "comparison diagram",
    "comparison story",
    "compass",
    "complement",
    "complementary",
    "compose",
    "composite",
    "concave polygon",
    "concentric circles",
    "concentric",
    "cone",
    "congruent",
    "consecutive",
    "constant function",
    "constant",
    "continuous model of area",
    "continuous model of volume",
    "continuous",
    "contour",
    "conversion fact",
    "conversion factor",
    "convert",
    "convex function",
    "convex polygon",
    "coordinate",
    "coplanar",
    "corresponding",
    "counterclockwise",
    "counting numbers",
    "counting up subtraction",
    "covariance",
    "covariate",
    "cover-up method",
    "cross multiplication",
    "cross product",
    "cross section",
    "cross-section",
    "cube root",
    "cube",
    "cubed",
    "cubic unit",
    "cubic",
    "cubit",
    "cup",
    "curved surface",
    "customary system of measurement",
    "customary unit",
    "cylinder",
    "cylindrical",
    "data",
    "decagon",
    "decimal divisor",
    "decimal expanded form",
    "decimal fraction",
    "decimal point",
    "decimal",
    "decimeter",
    "decompose",
    "deficient number",
    "degree",
    "delta",
    "denominator",
    "density",
    "dependent event",
    "dependent variable",
    "deposit",
    "derivative",
    "determinant",
    "diagonal",
    "diameter",
    "difference",
    "differential",
    "digit",
    "digital",
    "dilation",
    "dimension",
    "discrete model",
    "displacement method",
    "distance",
    "distribution",
    "distributive",
    "divide",
    "divided",
    "divides",
    "dividing",
    "dividend",
    "divisibility test",
    "divisible by",
    "divisible",
    "division",
    "divisor",
    "dodecahedron",
    "dot plot",
    "double number line diagram",
    "double stem plot",
    "doubles fact",
    "edge",
    "egyptian multiplication",
    "elevation",
    "embed figure",
    "end point",
    "endpoint",
    "enlarge",
    "equal group",
    "equal part",
    "equal",
    "equality",
    "equation",
    "equidistant mark",
    "equilateral polygon",
    "equilateral triangle",
    "equilateral",
    "equivalence",
    "equivalent expression",
    "equivalent fraction",
    "equivalent",
    "error bound",
    "error of measurement",
    "estimat",
    "estimate",
    "european subtraction",
    "even number",
    "event",
    "expand",
    "expanded form",
    "expanded notation",
    "expected outcome",
    "expected value",
    "exponent",
    "exponential function",
    "exponential growth",
    "expression",
    "extended fact",
    "face",
    "fact power",
    "fact triangle",
    "factor",
    "factored",
    "factoring",
    "factors",
    "factorial",
    "factors of number",
    "fahrenheit",
    "false number sentence",
    "figurate number",
    "flowchart",
    "fluid ounce",
    "formula",
    "fraction form",
    "fraction",
    "fractional part",
    "fractional unit",
    "frequency",
    "fulcrum",
    "function machine",
    "function",
    "furlong",
    "gallon",
    "gcd",
    "genus",
    "geoboard",
    "geometr",
    "geometric solid",
    "geometry template",
    "girth",
    "golden ratio",
    "golden rectangle",
    "gram",
    "graph key",
    "graph",
    "greatest common divisor",
    "greatest common factor",
    "grouping symbol",
    "half circle",
    "half-circle",
    "hashmark",
    "height of a parallelogram or triangle",
    "height of",
    "height",
    "hemisphere",
    "heptagon",
    "heptagonal",
    "hexagon",
    "hexagonal",
    "hierarchy",
    "histogram",
    "horizontal shift",
    "horizontal stretch",
    "horizontal",
    "hundred",
    "hundredth",
    "hypotenuse",
    "hypothesis",
    "icosahedron",
    "identity function",
    "identity matrix",
    "identity property of",
    "identity property",
    "improper fraction",
    "inch",
    "incircle",
    "indefinite integral",
    "independent event",
    "independent variable",
    "index of location",
    "indirect measurement",
    "inequality",
    "infinity",
    "input",
    "inscribed angle",
    "inscribed polygon",
    "instance of a pattern",
    "integer",
    "intercept",
    "intercepted arc",
    "interior angle",
    "interior of a figure",
    "interpolate",
    "interquartile range",
    "intersect",
    "interval",
    "inverse operation",
    "inverse",
    "iqr",
    "irrational number",
    "irrational root",
    "irrational",
    "isometry transformation",
    "isosceles trapezoid",
    "isosceles triangle",
    "isosceles",
    "joint probability",
    "joint variation",
    "juxtapose",
    "key sequence",
    "kilogram",
    "kilometer",
    "kite",
    "label",
    "landmark",
    "latitude",
    "lattice multiplication",
    "lcm",
    "least common denominator",
    "least common multiple",
    "left to right subtraction",
    "leg of a right triangle",
    "legs",
    "length",
    "like fraction",
    "like terms",
    "line graph",
    "line of reflection",
    "line of symmetry",
    "line plot",
    "line segment",
    "line symmetry",
    "line",
    "linear relationship",
    "lines of latitude",
    "lines of longitude",
    "liter",
    "local maximum",
    "local minimum",
    "locus",
    "logarithm",
    "logarithmic function",
    "logarithmic scale",
    "logic",
    "long division",
    "longitude",
    "lowest term",
    "magnitude estimate",
    "make ten",
    "map legend",
    "map scale",
    "mass",
    "maximum",
    "mean absolute deviation",
    "mean value",
    "mean",
    "measure of center",
    "measure",
    "measurement division",
    "measurement error",
    "measurement unit",
    "median",
    "meridian bar",
    "meter",
    "meters per second",
    "metric system",
    "metric unit",
    "metric",
    "midpoint",
    "mile",
    "milliliter",
    "millimeter",
    "millisecond",
    "minimum",
    "minuend",
    "mirror image",
    "mixed number",
    "mixed unit",
    "mobius",
    "modal",
    "mode",
    "multipl",
    "multiply",
    "multiplied",
    "multiplies",
    "multiple",
    "multiplication",
    "multiplying",
    "multiplication counting principle",
    "multiplication diagram",
    "multiplication fact",
    "multiplication symbol",
    "multiplication use class",
    "multiplicative identity",
    "multiplicative inverse",
    "multiplier",
    "mutually exclusive event",
    "natural number",
    "negative association",
    "negative exponent",
    "negative number",
    "negative rational number",
    "nested parentheses",
    "net score",
    "net weight",
    "net",
    "nonagon",
    "nonconvex polygon",
    "nonlinear",
    "normal distribution",
    "normal span",
    "normal",
    "number bond",
    "number disk",
    "number grid",
    "number line",
    "number path",
    "number sentence",
    "number sequence",
    "numeral",
    "numeration",
    "numerator",
    "numerical data",
    "numerical",
    "obtuse",
    "octagon",
    "octagonal",
    "octahedron",
    "odd number",
    "open proportion",
    "operation symbol",
    "operational",
    "opposite angle",
    "opposite change rule",
    "opposite of a number",
    "opposite side",
    "opposite vertex",
    "opposite",
    "order of magnitude",
    "order of operations",
    "order of rotation symmetry",
    "order of",
    "ordered pair",
    "ordered",
    "ordinal number",
    "orthogonal",
    "ounce",
    "outlier",
    "pace",
    "pan balance",
    "parabola",
    "parallel lines",
    "parallel plane",
    "parallel",
    "parallelogram",
    "parentheses",
    "part to part ratio",
    "part to whole ratio",
    "part whole fraction",
    "partial differences subtraction",
    "partial product",
    "partial products multiplication",
    "partial quotients division",
    "partial sums addition",
    "partition",
    "partitive division",
    "parts and total diagram",
    "pentagon",
    "pentagonal",
    "per capita",
    "per unit rate",
    "per",
    "percent circle",
    "percent",
    "percentage",
    "perfect number",
    "perfect square",
    "perfect triangle",
    "perimeter",
    "permutation",
    "perpendicular",
    "perpetual calendar",
    "pi",
    "picture graph",
    "pie graph",
    "pint",
    "pivot",
    "place value",
    "plane figure",
    "plane",
    "point symmetry",
    "point",
    "polar coordinate",
    "polygon",
    "polyhedron",
    "polynomial",
    "population density",
    "population",
    "positive association",
    "positive number",
    "pound",
    "power",
    "precise",
    "predict",
    "prediction line",
    "preimage",
    "prime factor",
    "prime factorization",
    "prime meridian",
    "prime number",
    "prism",
    "probability meter",
    "probability tree diagram",
    "probability",
    "product",
    "proper factor",
    "proper fraction",
    "property",
    "proportion",
    "proportional",
    "proportionality",
    "protractor",
    "pyramid",
    "pythagorean theorem",
    "quadrangle",
    "quadrant",
    "quadratic",
    "quadrilateral",
    "quart",
    "quarter circle",
    "quarter of",
    "quarter-circle",
    "quartile",
    "quick common denominator",
    "quotient",
    "quotitive division",
    "radian",
    "radius of",
    "radius",
    "random draw",
    "random experiment",
    "random number",
    "random sample",
    "random",
    "range",
    "rank",
    "rate diagram",
    # NOTE(review): trailing space kept as found — confirm it is intended.
    "rate multiplication ",
    "rate of change",
    "rate unit",
    "rate",
    "ratio of",
    "ratio",
    "rational equation",
    "rational number",
    "ray",
    "real number",
    "recall survey",
    "reciprocal",
    "rectang",
    "rectangle",
    "rectangular array",
    "rectangular coordinate grid",
    "rectangular prism",
    "rectangular pyramid",
    "rectangular",
    "rectilinear figure",
    "reflection",
    "reflex angle",
    "region",
    "regular polygon",
    "regular polyhedron",
    "regular tessellation",
    "relation symbol",
    "relative frequency",
    "remainder",
    "repeated addition",
    "repeating decimal",
    "representative",
    "revolution",
    "rhombus",
    "right angle",
    "right cone",
    "right cylinder",
    "right prism",
    "right pyramid",
    "right triangle",
    "rigid transformation",
    "roman numerals",
    "root",
    "rotate",
    "rotation symmetry",
    "rotation",
    "round off",
    "round-off",
    "ruler",
    "same change rule for subtraction",
    "sample",
    "scalar",
    "scale factor",
    "scale model",
    "scale of a map",
    "scale of a number line",
    "scale",
    "scaled graph",
    "scaled",
    "scalene triangle",
    "scalene",
    "scatter plot",
    "scattergram",
    "sector",
    "segment",
    "semi-circle",
    "semicircle",
    "sequence",
    "set",
    "sign",
    "significant digit",
    "significant figure",
    "similar figures",
    "similar",
    "simpler form",
    "simplify",
    "simulation",
    "situation diagram",
    "skew line",
    "slanted",
    "slide rule",
    "slope",
    "solid figure",
    "solution",
    "span",
    "speed",
    "sphere",
    "square root",
    "square unit",
    "square",
    "squared",
    "stacked bar graph",
    "standard form",
    "standard unit",
    "statistic",
    "stem and leaf plot",
    "step graph",
    "straight angle",
    "straightedge",
    "subset of",
    "substitute",
    "subtract",
    "subtrahend",
    "sum of",
    "sum",
    "supplementary angle",
    "surface area",
    "surface",
    "survey",
    "symmetric",
    "symmetry",
    "system of equation",
    "system of",
    "table",
    "take from ten",
    "tally",
    "tangent circle",
    "tangent",
    "tangram",
    "tape diagram",
    "temperature",
    "template",
    "tens place",
    "tenth",
    "term",
    "terminating decimal",
    "tessellat",
    "tessellate",
    "tessellation",
    "tetrahedron",
    "tetromino",
    "theorem",
    "thermometer",
    "thousand",
    "thousandth",
    "tile",
    "tiling",
    "time graph",
    "timeline",
    "top heavy fraction",
    "topological",
    "topology",
    "total area",
    "total of",
    "total surface",
    "total volume",
    "trade first subtraction",
    "transformation",
    "translation",
    "transversal",
    "trapezoid",
    "tree diagram",
    "triangle",
    "triangular",
    "true number sentence",
    "truncate",
    "twin prime",
    "two-way table",
    "unit cube",
    "unit form",
    "unit fraction",
    "unit interval",
    "unit price",
    "unit rate",
    "unit square",
    "unit",
    "unknown",
    "unlike denominator",
    "unlike fraction",
    "value",
    # NOTE(review): trailing space kept as found — confirm it is intended.
    "vanishing ",
    "variability",
    "variable",
    "velocity",
    "venn diagram",
    "vernal equinox",
    "vertex",
    "vertical",
    "volume of",
    "volume",
    "weight",
    "whole number",
    "whole unit",
    "whole",
    "width",
    "withdrawal",
    "word form",
    "x axes",
    "x axis",
    "x intercept",
    "x-axes",
    "x-axis",
    "y axes",
    "y axis",
    "y intercept",
    "y-axes",
    "y-axis",
    "y-intercept",
    "yard",
    "zero property of multiplication",
    "zero",
]
| |
|
# Words for which is_plural() always answers False, so plural_to_singular()
# leaves them untouched.
PLURAL_TO_SINGULAR_EXCLUSIONS = [
    "axis", "continuous", "data", "minus", "miss", "plus", "yes",
]

# Module-wide inflect engine shared by the singular/plural helpers below.
p = inflect.engine()
| |
|
def is_plural_regex(word):
    """Guess from spelling alone whether ``word`` is plural.

    A word counts as plural when, ignoring case, it ends in "s" but not in
    "ss".

    Args:
        word: The word to inspect.

    Returns:
        bool: ``True`` if the suffix heuristic matches, ``False`` otherwise
        (always a real ``bool``, never ``None``).
    """
    lowered = word.lower()
    # "es$" and "ies$" are special cases of "s$", so a single suffix test
    # covers every alternative of the former pattern.
    ends_with_s = re.search(r's$', lowered) is not None
    ends_with_double_s = re.search(r'ss$', lowered) is not None
    return ends_with_s and not ends_with_double_s
| |
|
def is_plural_wordnet(word):
    """Guess whether ``word`` is plural by comparing WordNet noun entries.

    The word is treated as plural when removing one trailing "s" yields a
    form with more noun synsets than the word itself has.

    Args:
        word: The word to inspect.

    Returns:
        bool: ``True`` if the form without its final "s" has more noun
        synsets than ``word``; ``False`` otherwise, including when ``word``
        does not end in "s".
    """
    if not word.endswith('s'):
        # Nothing to strip: both lookups would be for the same string, so
        # the counts could never differ.
        return False
    # Drop exactly one "s"; str.rstrip('s') would remove every trailing "s"
    # and turn e.g. "class" into "cla".
    stem = word[:-1]
    stem_synsets = wordnet.synsets(stem, pos=wordnet.NOUN)
    word_synsets = wordnet.synsets(word, pos=wordnet.NOUN)
    return len(stem_synsets) > len(word_synsets)
| |
|
def is_plural_pos(word):
    """Determine if a word is plural using NLTK's part-of-speech tagging.

    Args:
        word: The word to inspect. Only its first token is tagged.

    Returns:
        bool: ``True`` if the first token is tagged ``NNS`` or ``NNPS``;
        ``False`` otherwise, including when ``word`` yields no tokens.
    """
    tokens = word_tokenize(word)
    if not tokens:
        # Empty or whitespace-only input: there is no first token to tag.
        return False
    first_tag = pos_tag(tokens)[0][1]
    return first_tag in ("NNS", "NNPS")
| |
|
def is_plural(word):
    """Report whether ``word`` looks plural.

    Words listed in ``PLURAL_TO_SINGULAR_EXCLUSIONS`` are never plural. Any
    other word is plural as soon as one detector agrees; detectors run in
    the order spelling, POS tag, WordNet, stopping at the first hit.
    """
    if word in PLURAL_TO_SINGULAR_EXCLUSIONS:
        return False
    detectors = (is_plural_regex, is_plural_pos, is_plural_wordnet)
    return any(detect(word) for detect in detectors)
| |
|
def singular_to_plural(word):
    """Return the inflect plural of ``word``, or ``word`` if inflect yields nothing."""
    pluralized = p.plural(word)
    if pluralized:
        return pluralized
    return word
| |
|
def plural_to_singular(word):
    """Return the inflect singular of ``word`` when ``is_plural`` deems it plural.

    The word is returned unchanged when it is not plural or when inflect
    gives back a falsy result.
    """
    if not is_plural(word):
        return word
    singular = p.singular_noun(word)
    return singular if singular else word
| |
|
# Plural form of every math term, in the same order as MATH_WORDS.
plural_MATH_WORDS = list(map(singular_to_plural, MATH_WORDS))

# Grow the vocabulary in place so it holds the original terms followed by
# their plurals.
MATH_WORDS.extend(plural_MATH_WORDS)
| |
|
def get_num_words(text):
    """Count the words in ``text``, ignoring bracketed spans and punctuation.

    Args:
        text: The string to measure.

    Returns:
        int: Number of whitespace-separated tokens left after bracketed
        spans such as "[inaudible]" are removed and every punctuation
        character is replaced by a space. ``0`` when ``text`` is not a
        string (a notice is printed in that case).
    """
    if not isinstance(text, str):
        # Wrapped in a tuple so that tuple inputs are formatted rather than
        # unpacked by the % operator.
        print("%s is not a string" % (text,))
        return 0
    # Bracketed spans must go first: the punctuation pass below turns "["
    # and "]" into spaces, after which this pattern can no longer match.
    # Non-greedy, so two separate spans do not swallow the words between them.
    text = re.sub(r'\[.+?\]', " ", text)
    text = replace.sub(' ', text)
    # str.split() already collapses whitespace runs and ignores the ends.
    return len(text.split())
| |
|
def number_to_words(num):
    """Spell out a numeric string in words.

    Commas are removed before the value is handed to ``num2words``.

    Args:
        num: Text of the number, e.g. the string matched by a number regex.

    Returns:
        The ``num2words`` result, or ``num`` unchanged when the conversion
        fails for any reason.
    """
    try:
        digits = re.sub(",", "", num)
        return num2words(digits)
    except Exception:
        # Best effort: keep the original text. Catching Exception rather
        # than using a bare except lets KeyboardInterrupt and SystemExit
        # propagate.
        return num
| |
|
| |
|
def clean_str(s):
    """Normalize ``s`` with ``cleantext.clean``, keeping punctuation.

    The call asks for unicode fixing, ASCII transliteration, lower-casing
    and line-break removal. URLs, e-mails and phone numbers are replaced by
    ``<URL>``, ``<EMAIL>`` and ``<PHONE>``; each number match is replaced
    by the result of ``number_to_words``.
    """
    def spell_out_number(match):
        return number_to_words(match.group())

    return clean(
        s,
        fix_unicode=True,
        to_ascii=True,
        lower=True,
        no_line_breaks=True,
        no_urls=True,
        no_emails=True,
        no_phone_numbers=True,
        no_numbers=True,
        no_digits=False,
        no_currency_symbols=False,
        no_punct=False,
        replace_with_url="<URL>",
        replace_with_email="<EMAIL>",
        replace_with_phone_number="<PHONE>",
        replace_with_number=spell_out_number,
        replace_with_digit="0",
        replace_with_currency_symbol="<CUR>",
        lang="en",
    )
| |
|
def clean_str_nopunct(s):
    """Normalize ``s`` with ``cleantext.clean``, with ``no_punct`` enabled.

    The call asks for unicode fixing, ASCII transliteration, lower-casing
    and line-break removal. URLs, e-mails and phone numbers are replaced by
    ``<URL>``, ``<EMAIL>`` and ``<PHONE>``; each number match is replaced
    by the result of ``number_to_words``.
    """
    def spell_out_number(match):
        return number_to_words(match.group())

    return clean(
        s,
        fix_unicode=True,
        to_ascii=True,
        lower=True,
        no_line_breaks=True,
        no_urls=True,
        no_emails=True,
        no_phone_numbers=True,
        no_numbers=True,
        no_digits=False,
        no_currency_symbols=False,
        no_punct=True,
        replace_with_url="<URL>",
        replace_with_email="<EMAIL>",
        replace_with_phone_number="<PHONE>",
        replace_with_number=spell_out_number,
        replace_with_digit="0",
        replace_with_currency_symbol="<CUR>",
        lang="en",
    )
| |
|
| |
|
| |
|
class MultiHeadModel(BertPreTrainedModel):
    """BERT encoder with one linear head per task on its pooled output.

    Each entry of ``head2size`` creates a linear layer that maps the pooled
    BERT output to that many logits. A head with a single output is trained
    with mean squared error; a head with several outputs is trained with
    cross entropy.
    """

    def __init__(self, config, head2size):
        """Build the encoder, the dropout layer and one linear layer per head.

        Args:
            config: BERT configuration; ``hidden_size`` and
                ``hidden_dropout_prob`` are read from it and its
                ``num_labels`` is set to 1.
            head2size: Mapping from head name to the number of outputs of
                that head.
        """
        super(MultiHeadModel, self).__init__(config, head2size)
        config.num_labels = 1
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.heads = nn.ModuleDict({
            head_name: nn.Linear(config.hidden_size, num_labels)
            for head_name, num_labels in head2size.items()
        })

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                head2labels=None, return_pooler_output=False, head2mask=None,
                nsp_loss_weights=None):
        """Run the encoder and every head; optionally compute per-head losses.

        Args:
            input_ids: Token ids passed to the BERT encoder.
            token_type_ids: Optional segment ids passed to the encoder.
            attention_mask: Optional attention mask passed to the encoder.
            head2labels: Optional mapping from head name to label tensor. A
                loss is computed for every entry.
            return_pooler_output: When true, the encoder's pooled output
                (before dropout) is added to the result.
            head2mask: Optional mapping from a single-output head name to
                the key in ``head2labels`` whose tensor weights that head's
                per-example loss (presumably a 0/1 mask — confirm with the
                caller).
            nsp_loss_weights: Optional per-class weights for the cross
                entropy loss of multi-output heads; ``None`` means
                unweighted.

        Returns:
            dict: ``"<head>_logits"`` for every head, ``"<head>_loss"`` for
            every entry of ``head2labels``, and ``"pooler_output"`` when
            requested.
        """
        output = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_attentions=False, output_hidden_states=False, return_dict=True)
        pooled_output = self.dropout(output["pooler_output"])
        # Stay on the device the encoder ran on. Picking "cuda" whenever it
        # is available would move the activations away from the heads if the
        # model itself lives on the CPU or on another GPU.
        device = pooled_output.device

        head2logits = {}
        return_dict = {}
        for head_name, head in self.heads.items():
            logits = head(pooled_output).float()
            head2logits[head_name] = logits
            return_dict[head_name + "_logits"] = logits

        if head2labels is not None:
            for head_name, labels in head2labels.items():
                logits = head2logits[head_name]
                loss_key = head_name + "_loss"

                if logits.shape[1] == 1:
                    # Single output: regression with mean squared error.
                    if head2mask is not None and head_name in head2mask:
                        mask = head2labels[head2mask[head_name]]
                        num_positives = mask.sum()
                        if num_positives == 0:
                            # No example selected by the mask: zero loss.
                            return_dict[loss_key] = torch.tensor([0]).to(device)
                        else:
                            # Mask-weighted average of the per-example errors.
                            loss_fct = MSELoss(reduction='none')
                            loss = loss_fct(logits.view(-1), labels.float().view(-1))
                            return_dict[loss_key] = loss.dot(mask.float().view(-1)) / num_positives
                    else:
                        loss_fct = MSELoss()
                        return_dict[loss_key] = loss_fct(logits.view(-1), labels.float().view(-1))
                else:
                    # Several outputs: classification with cross entropy.
                    # The weights are optional, so only convert them when given.
                    class_weights = None if nsp_loss_weights is None else nsp_loss_weights.float()
                    loss_fct = CrossEntropyLoss(weight=class_weights)
                    return_dict[loss_key] = loss_fct(logits, labels.view(-1))

        if return_pooler_output:
            return_dict["pooler_output"] = output["pooler_output"]

        return return_dict
| |
|
class InputBuilder:
    """Common base for classes that turn dialogue segments into model inputs."""

    def __init__(self, tokenizer):
        """Keep the tokenizer and a one-element list holding its mask token id."""
        self.tokenizer = tokenizer
        self.mask = [tokenizer.mask_token_id]

    def build_inputs(self, history, reply, max_length):
        """Build model inputs; subclasses must override this."""
        raise NotImplementedError

    def mask_seq(self, sequence, seq_id):
        """Replace ``sequence[seq_id]`` with the mask segment, in place.

        Returns the same ``sequence`` object.
        """
        sequence[seq_id] = self.mask
        return sequence

    @classmethod
    def _combine_sequence(cls, history, reply, max_length, flipped=False):
        """Truncate every segment to ``max_length`` items and join them.

        The reply goes after the history, or before it when ``flipped`` is
        true.
        """
        truncated_history = [segment[:max_length] for segment in history]
        truncated_reply = reply[:max_length]
        if not flipped:
            return truncated_history + [truncated_reply]
        return [truncated_reply] + truncated_history
| |
|
| |
|
class BertInputBuilder(InputBuilder):
    """Builds ``input_ids`` and ``token_type_ids`` for a BERT model."""

    def __init__(self, tokenizer):
        """Store the tokenizer's special-token ids and the input field names."""
        InputBuilder.__init__(self, tokenizer)
        self.cls = [tokenizer.cls_token_id]
        self.sep = [tokenizer.sep_token_id]
        self.model_inputs = ["input_ids", "token_type_ids", "attention_mask"]
        self.padded_inputs = ["input_ids", "token_type_ids"]
        self.flipped = False

    def build_inputs(self, history, reply, max_length, input_str=True):
        """See base class.

        With ``input_str`` true, ``history`` and ``reply`` are raw strings
        and are tokenized first; otherwise they are used as given. Every
        segment is truncated to ``max_length`` and closed with the separator
        id, and the first segment is prefixed with the CLS id. Segment type
        ids alternate between 0 and 1, counted back from the last segment,
        which gets 0.
        """
        tokenizer = self.tokenizer
        if input_str:
            history = [
                tokenizer.convert_tokens_to_ids(tokenizer.tokenize(turn))
                for turn in history
            ]
            reply = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(reply))

        segments = self._combine_sequence(history, reply, max_length, self.flipped)
        segments = [segment + self.sep for segment in segments]
        segments[0] = self.cls + segments[0]

        last_speaker = 0
        other_speaker = 1
        total_segments = len(segments)
        token_type_ids = []
        for position, segment in enumerate(segments):
            if (total_segments - position) % 2 == 1:
                speaker = last_speaker
            else:
                speaker = other_speaker
            token_type_ids.extend([speaker] * len(segment))

        return {
            "input_ids": list(chain.from_iterable(segments)),
            "token_type_ids": token_type_ids,
        }