Upload 3 files
Browse filesMetrum validator
- .gitattributes +1 -0
- BPE_validator_1697833311028 +3 -0
- poet_utils.py +418 -0
- validators.py +259 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
BPE_validator_1697833311028 filter=lfs diff=lfs merge=lfs -text
|
BPE_validator_1697833311028
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e16e690f718daaf133d6e78e82d416ef62a84db3f76f70592460e53da2f6a8fa
|
| 3 |
+
size 498951742
|
poet_utils.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Most common rhyme schemas (None is the fallback "unknown/other" class).
# FIX: the original was missing a comma after "ABABXX", so Python's implicit
# string concatenation silently produced the single entry "ABABXXABABCD",
# dropping both "ABABXX" and "ABABCD" from the list.
RHYME_SCHEMES = ["ABAB", "ABBA",
                 "XAXA", "ABCB",
                 "AABB", "AABA",
                 "AAAA", "AABC",
                 "XXXX", "AXAX",
                 "AABBCC", "AABCCB",
                 "ABABCC", "AABCBC",
                 "AAABAB", "ABABXX",
                 "ABABCD", "ABABAB",
                 "ABABBC", "ABABCB",
                 "ABBAAB", "AABABB",
                 "ABCBBB", "ABCBCD",
                 "ABBACC", "AABBCD",
                 None]

# Subset of schemes considered regular/canonical forms.
NORMAL_SCHEMES = ["ABAB", "ABBA", "AABB", "AABBCC", "ABABCC", "ABBACC", "ABBAAB"]
# First 200 most common verse endings (two-character suffixes, mostly Czech);
# None is the fallback "unknown/other" class.
VERSE_ENDS = ['ní', 'ou', 'em', 'la', 'ch', 'ti', 'tí', 'je', 'li', 'al', 'ce', 'ky', 'ku', 'ně', 'jí', 'ly', 'il', 'en', 'né',
              'lo', 'ne', 'vá', 'ny', 'se', 'na', 'ím', 'st', 'le', 'ný', 'ci', 'mi', 'ka', 'ná', 'lí', 'cí', 'ží', 'čí', 'ám',
              'hu', 'ho', 'ří', 'dí', 'nu', 'dy', 'ší', 'ví', 'du', 'ta', 'as', 'tě', 'ře', 'ru', 'vé', 'ým', 'at', 'ek', 'el',
              'te', 'tu', 'ká', 'ji', 'ět', 'ni', 'še', 'vy', 'dá', 'it', 'tá', 'ty', 'lý', 'lá', 'mu', 'va', 'ém', 'ěl', 'no',
              'že', 'vu', 'ál', 'há', 'ků', 'vý', 'bě', 'hy', 'lé', 'sy', 'me', 'es', 'ra', 'ak', 'ad', 'ry', 'zí', 'et', 'rá',
              'de', 'vě', 'ři', 'lu', 'át', 'da', 'ko', 'ha', 'té', 'to', 'ed', 'ít', 'ký', 'ši', 'íš', 'sí', 'íc', 'ze', 'si',
              'be', 'má', 'mě', 'by', 'su', 'tý', 'ej', 'či', 'če', 'my', 'ké', 'án', 'ma', 'ům', 'or', 'nů', 'áš', 'dě', 'ec',
              'mí', 'ev', 'ád', 'ut', 'am', 'yl', 'ul', 'tů', 'bu', 'ás', 'ba', 'ud', 'ář', 'ie', 'od', 'pí', 'ůj', 'eš', 'hý',
              'bí', 'íž', 'dé', 'an', 'sa', 've', 'lů', 'ín', 'id', 'in', 'mů', 'di', 'hů', 'ic', 'on', 'eň', 'zy', 'ol', 'vo',
              'ži', 'sů', 'ík', 'vi', 'oj', 'uk', 'uh', 'oc', 'iž', 'sá', 'ěv', 'dý', 'av', 'iv', 'rů', 'ot', 'py', 'mé', 'um',
              'zd', 'dů', 'ar', 'rý', 'aň', 'sk', 'ok', 'om', 'už', 'ěk', 'ov', 'er', 'uď', 'bi', 'áz', 'ýt', 'ěm', 'ik', 'eď',
              'ob', 'ák', 'ůh', 'ár', 'sť', 'ro', 'yt', 'ěj', 'mý', 'us', 'ěn', 'ii', 'hé', 'áj', 'pá', 'íh', 'ih', 'zi', 'bá',
              'eč', 'ré', 'ír', 'ců', 'uj', 'dl', 'áh', 'ův', 'aj', 'eh', 'éž', 'pu', 'ýš', 'zu', 'im', 're', 'up', 'os', 'ah',
              'rt', 'mo', 'áň', 'sl', 'íl', 'cy', 'ys', 'hl', 'oh', 'ěz', 'ěs', 'ež', 'ií', 'vů', 'kl', 'az', 'cý', 'pe', 'ěd',
              'do', 'yn', 'šť', 'ez', 'ůl', 'ub', 'ln', 'yk', 'pý', 'ěc', 'ať', 'já', 'op', 'eb', 'áč', 'ív', 'áv', 'jů', 'sý',
              'is', ' a', 'iť', 'ěř', 'za', 'uť', 'ěh', 'pě', 'íp', 'áž', 'ěď', 'bů', 'ep', 'iš', 'yš', 'ia', 'pa', 'un', 'ěť',
              'pů', 'eř', 'tr', 'nt', 'pi', 'tl', 'eť', 'ju', 'oď', 'řů', 'ýr', 'rh', 'ur', 'zý', 'ěž', 'ýn', 'ip', 'bý', 'pé',
              'íň', 'zů', 'čů', 'uč', 'éb', 'ap', 'ón', 'uř', 'ůr', 'íř', 'ač', 'co', 'íč', 'až', 'ls', 'ůž', 'ěr', 'oč', 'ič',
              'ař', 'ěš', 'uv', 'ůz', 'oň', 'bé', 'sé', 'yč', 'áť', 'jď', 'ri', 'íť', 'oš', 'ůň', 'ék', 'uc', 'rk', 'bo', 'ýl',
              'oť', 'íz', 'lh', 'so', 'áb', 'ja', 'ij', 'ůn', 'rv', 'žů', 'ab', 'he', 'íd', 'ér', 'uš', 'ýž', 'fá', 'rs', 'rn',
              'iz', 'ib', 'ki', 'éd', 'év', 'rd', 'yb', 'oz', 'oř', 'ét', 'ož', 'ga', 'yň', 'rp', 'nd', 'of', 'rť', 'iď', 'ýv',
              'yz', None]
# Year boundaries that publication years are snapped to (None = unknown year).
POET_YEARS_BUCKETS = [1800, 1820, 1840, 1860, 1880, 1900, 1920, 1940, 1960, None]
# Possible meter types, one-character codes (None = unknown meter).
METER_TYPES = ["J","T","D","A","X","Y","N","H","P", None]
# Translation of meter names to one-char codes (identity for already-short codes).
METER_TRANSLATE = {
    "J":"J",
    "T":"T",
    "D":"D",
    "A":"A",
    "X":"X",
    "Y":"Y",
    "hexameter": "H",
    "pentameter": "P",
    "N":"N"
}
# Tokenizer special tokens.
PAD = "<|PAD|>"
UNK = "<|UNK|>"
EOS = "<|EOS|>"
# Basic characters considered in rhyme and syllable processing (43 entries,
# including the empty string and a single space).
VALID_CHARS = [""," ",'a','á','b','c','č','d','ď','e','é','ě',
               'f','g','h','i','í','j','k','l','m','n','ň',
               'o','ó','p','q','r','ř','s','š','t','ť','u',
               'ú','ů','v','w','x','y','ý','z','ž']
| 68 |
+
import re
|
| 69 |
+
import numpy as np
|
| 70 |
+
|
class TextManipulation:
    """Namespace of static helpers for cleaning up raw poem text.

    Every method takes a string and returns a transformed string.
    """

    @staticmethod
    def _remove_most_nonchar(raw_text, lower_case=True):
        """Strip most punctuation and typographic characters.

        Args:
            raw_text (str): text to clean
            lower_case (bool, optional): lowercase the result. Defaults to True.

        Returns:
            str: cleaned text
        """
        cleaned = re.sub(r'[–\„\“\’\;\:()\]\[\_\*\‘\”\'\-\—\"]+', "", raw_text)
        if lower_case:
            return cleaned.lower()
        return cleaned

    @staticmethod
    def _remove_all_nonchar(raw_text):
        """Strip everything except word characters and whitespace (digits removed too).

        Args:
            raw_text (str): text to clean

        Returns:
            str: cleaned text
        """
        return re.sub(r'([^\w\s]+|[0-9]+)', '', raw_text)

    @staticmethod
    def _year_bucketor(raw_year):
        """Snap a year string to the nearest bucket in POET_YEARS_BUCKETS.

        Args:
            raw_year (str): year string to bucketize

        Returns:
            str: bucketized year, or "NaN" for invalid input
        """
        if not TextAnalysis._is_year(raw_year) or raw_year == "NaN":
            return "NaN"
        # Nearest boundary wins (the trailing None bucket is excluded).
        distances = np.abs(np.asarray(POET_YEARS_BUCKETS[:-1]) - int(raw_year))
        return str(POET_YEARS_BUCKETS[np.argmin(distances)])
class TextAnalysis:
    """Static helpers that classify strings found in poem parameter/content lines.

    Returns:
        Union[str, bool, dict, numpy.ndarray]: analysis results
    """

    # Possible keys when the returned type is a dict.
    POET_PARAM_LIST = ["RHYME", "YEAR", "METER", "LENGTH", "END", "TRUE_LENGTH", "TRUE_END"]

    @staticmethod
    def _is_meter(meter:str):
        """Return True when the string is a known meter code.

        Args:
            meter (str): string to analyze

        Returns:
            bool: True if `meter` is a meter code (the trailing None is excluded)
        """
        return meter in METER_TYPES[:-1]

    @staticmethod
    def _is_year(year:str):
        """Return True when the string is a 4-digit year or the special "NaN".

        Args:
            year (str): string to analyze

        Returns:
            bool: True for "NaN" or a numeric year strictly between 1000 and 10000
        """
        if year == "NaN":
            return True
        return year.isdigit() and 1_000 < int(year) < 10_000

    @staticmethod
    def _rhyme_like(rhyme:str):
        """Return True when the string is shaped like a rhyme schema.

        Args:
            rhyme (str): string to analyze

        Returns:
            bool: True for an all-uppercase string of length 3..6
        """
        return rhyme.isupper() and 3 <= len(rhyme) <= 6

    @staticmethod
    def _rhyme_vector(rhyme:str) -> np.ndarray:
        """Build a one-hot rhyme-schema vector.

        Args:
            rhyme (str): schema string

        Returns:
            numpy.ndarray: one-hot vector over RHYME_SCHEMES; unknown schemas
            activate the last (fallback) slot
        """
        vector = np.zeros(len(RHYME_SCHEMES))
        slot = RHYME_SCHEMES.index(rhyme) if rhyme in RHYME_SCHEMES else -1
        vector[slot] = 1
        return vector

    @staticmethod
    def _rhyme_or_not(rhyme_str:str) -> np.ndarray:
        """Build a 2-element flag vector: [known schema, unknown schema].

        Args:
            rhyme_str (str): schema string

        Returns:
            numpy.ndarray: one-hot boolean flag vector
        """
        flags = np.zeros(2)
        flags[0 if rhyme_str in RHYME_SCHEMES else 1] = 1
        return flags

    @staticmethod
    def _metre_vector(metre: str) -> np.ndarray:
        """Build a one-hot metre vector.

        Args:
            metre (str): metre code

        Returns:
            numpy.ndarray: one-hot vector over METER_TYPES; unknown codes
            activate the second-to-last slot
        """
        vector = np.zeros(len(METER_TYPES))
        slot = METER_TYPES.index(metre) if metre in METER_TYPES else len(METER_TYPES) - 2
        vector[slot] = 1
        return vector

    @staticmethod
    def _first_line_analysis(text:str):
        """Extract RHYME, METER and YEAR parameters from a parameter line.

        Args:
            text (str): parameter line string

        Returns:
            dict: analysis result (empty for a blank line)
        """
        stripped = text.strip()
        if not stripped:
            return {}
        params = {}
        # Classify each whitespace-separated token; meter wins over year over rhyme.
        for token in stripped.split():
            if TextAnalysis._is_meter(token):
                params["METER"] = token
            elif TextAnalysis._is_year(token):
                # Years are bucketized so they fit the model's classes.
                params["YEAR"] = TextManipulation._year_bucketor(token)
            elif TextAnalysis._rhyme_like(token):
                params["RHYME"] = token
        return params

    @staticmethod
    def _is_line_length(length:str):
        """Return True when the string is a syllable-count parameter.

        Args:
            length (str): string to analyze

        Returns:
            bool: True for a numeric value strictly between 1 and 100
        """
        return length.isdigit() and 1 < int(length) < 100

    @staticmethod
    def _is_line_end(end:str):
        """Return True when the string is a valid ending syllable/sequence parameter.

        Args:
            end (str): string to analyze

        Returns:
            bool: True for an alphabetic string of at most 5 characters
        """
        return end.isalpha() and len(end) <= 5

    @staticmethod
    def _continuos_line_analysis(text:str):
        """Extract LENGTH, END, TRUE_LENGTH and TRUE_END from a content line.

        Args:
            text (str): content line to analyze

        Returns:
            dict: analysis result (empty for a blank line)
        """
        # Drop most separators first; bail out on an effectively empty line.
        stripped = TextManipulation._remove_most_nonchar(text).strip()
        if not stripped:
            return {}
        params = {}
        tokens = stripped.split()
        if TextAnalysis._is_line_length(tokens[0]):
            params["LENGTH"] = int(tokens[0])
        if len(tokens) > 1 and TextAnalysis._is_line_end(tokens[1]):
            params["END"] = tokens[1]
        if len(tokens) > 3:
            # Tokens after the third are the actual verse text; count its syllables.
            params["TRUE_LENGTH"] = len(SyllableMaker.syllabify(" ".join(tokens[3:])))
        # TRUE_END needs alpha chars only, so everything else is removed.
        alpha_only = TextManipulation._remove_all_nonchar(stripped).strip()
        if len(alpha_only) > 2:
            params["TRUE_END"] = SyllableMaker.syllabify(alpha_only)[-1]

        return params

    @staticmethod
    def _is_param_line(text:str):
        """Return True when the line carries RHYME, METER or YEAR parameters.

        Args:
            text (str): line to analyze

        Returns:
            bool: True if the line is a parameter line
        """
        stripped = text.strip()
        if not stripped:
            return False
        analysis = TextAnalysis._first_line_analysis(stripped)
        return any(key in analysis for key in ("RHYME", "METER", "YEAR"))
# NON-Original code!
# Taken from Barbora Štěpánková
class SyllableMaker:
    """Static class with methods for separating string to list of Syllables

    Returns:
        list: List of syllables
    """

    @staticmethod
    def syllabify(text : str) -> list[str]:
        """Split text into a flat list of syllables.

        Args:
            text (str): text to syllabify

        Returns:
            list[str]: syllables of all words, in order of appearance
        """
        # Keep only alphabetic "words" (Czech alphabet plus German umlauts).
        words = re.findall(r"[aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzžAÁBCČDĎEÉĚFGHIÍJKLMNŇOÓPQRŘSŠTŤUÚŮVWXYÝZŽäöüÄÜÖ]+", text)
        syllables : list[str] = []

        i = 0
        while i < len(words):
            word = words[i]

            # Czech non-syllabic prepositions (k, v, s, z) are merged with the
            # following word before syllabification.
            if (word.lower() == "k" or word.lower() == "v" or word.lower() == "s" or word.lower() == "z") and i < len(words) - 1 and len(words[i + 1]) > 1:
                i += 1
                word = word + words[i]

            letter_counter = 0

            # Get syllables: mask the word and split the mask
            for syllable_mask in SyllableMaker.__split_mask(SyllableMaker.__create_word_mask(word)):
                word_syllable = ""
                # The mask is positionally aligned with the word: copy as many
                # characters from the word as the mask chunk is long (the mask
                # characters themselves are never used).
                for character in syllable_mask:
                    word_syllable += word[letter_counter]
                    letter_counter += 1

                syllables.append(word_syllable)

            i += 1

        return syllables


    @staticmethod
    def __create_word_mask(word : str) -> str:
        """Encode a word as a positional mask of V (vocal), K (consonant) and 0
        (character grouped with its neighbor) symbols."""
        word = word.lower()

        vocals = r"[aeiyouáéěíýóůúäöü]"
        consonants = r"[bcčdďfghjklmnňpqrřsštťvwxzž]"

        # NOTE: order matters — digraphs and diphthongs must be collapsed before
        # the generic vocal/consonant classification runs.
        replacements = [
            #double letters
            ('ch', 'c0'),
            ('rr', 'r0'),
            ('ll', 'l0'),
            ('nn', 'n0'),
            ('th', 't0'),

            # au, ou, ai, oi
            (r'[ao]u', '0V'),
            (r'[ao]i','0V'),

            # eu at the beginning of the word
            (r'^eu', '0V'),

            # now all vocals
            (vocals, 'V'),

            # r,l that act like vocals in syllables
            (r'([^V])([rl])(0*[^0Vrl]|$)', r'\1V\3'),

            # sp, st, sk, št, Cř, Cl, Cr, Cv
            (r's[pt]', 's0'),
            (r'([^V0lr]0*)[řlrv]', r'\g<1>0'),
            (r'([^V0]0*)sk', r'\1s0'),
            (r'([^V0]0*)št', r'\1š0'),

            (consonants, 'K')
        ]

        for (original, replacement) in replacements:
            word = re.sub(original, replacement, word)

        return word


    @staticmethod
    def __split_mask(mask : str) -> list[str]:
        """Insert "/" syllable boundaries into a V/K/0 mask and split on them."""
        # NOTE: applied in order; earlier rules create the separators that later
        # rules reposition.
        replacements = [
            # vocal at the beginning
            (r'(^0*V)(K0*V)', r'\1/\2'),
            (r'(^0*V0*K0*)K', r'\1/K'),

            # dividing the middle of the word
            (r'(K0*V(K0*$)?)', r'\1/'),
            (r'/(K0*)K', r'\1/K'),
            (r'/(0*V)(0*K0*V)', r'/\1/\2'),
            (r'/(0*V0*K0*)K', r'/\1/K'),

            # add the last consonant to the previous syllable
            (r'/(K0*)$', r'\1/')
        ]

        for (original, replacement) in replacements:
            mask = re.sub(original, replacement, mask)

        # Drop a trailing separator so split() does not yield an empty last chunk.
        if len(mask) > 0 and mask[-1] == "/":
            mask = mask[0:-1]

        return mask.split("/")
validators.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import transformers
|
| 3 |
+
import jellyfish
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from transformers import AutoModelForMaskedLM
|
| 6 |
+
from .poet_utils import RHYME_SCHEMES, METER_TYPES
|
| 7 |
+
|
| 8 |
+
from torch.utils.data import DataLoader, Dataset
|
| 9 |
+
from pytorch_optimizer import SAM,GSAM, ProportionScheduler, AdamP
|
| 10 |
+
|
class ValidatorInterface(torch.nn.Module):
    """Abstract base class for all validator models.

    Args:
        torch (_type_): subclasses torch.nn.Module for torch/huggingface integration
    """

    def __init__(self, *args, **kwargs) -> None:
        """Construct the parent torch.nn.Module."""
        super().__init__(*args, **kwargs)

    def forward(self, input_ids=None, attention_mask=None, *args, **kwargs):
        """Compute model output and model loss.

        Args:
            input_ids (_type_, optional): model inputs. Defaults to None.
            attention_mask (_type_, optional): attention mask where padding starts. Defaults to None.

        Raises:
            NotImplementedError: abstract method — must be overridden
        """
        raise NotImplementedError()

    def predict(self, input_ids=None, *args, **kwargs):
        """Compute model outputs only.

        Args:
            input_ids (_type_, optional): model inputs. Defaults to None.

        Raises:
            NotImplementedError: abstract method — must be overridden
        """
        raise NotImplementedError()

    def validate(self, input_ids=None, *args, **kwargs):
        """Validate the model given labels; does not use the loss.

        Args:
            input_ids (_type_, optional): model inputs. Defaults to None.

        Raises:
            NotImplementedError: abstract method — must be overridden
        """
        raise NotImplementedError()
class RhymeValidator(ValidatorInterface):
    """Validator that classifies a poem's rhyme scheme.

    A pretrained masked LM serves as the backbone; the first-token embedding of
    the last hidden layer feeds a linear head over RHYME_SCHEMES.
    """

    def __init__(self, pretrained_model, *args, **kwargs) -> None:
        """Load the backbone MLM and attach the rhyme classification head.

        Args:
            pretrained_model (str): name or path of the pretrained masked LM
        """
        super().__init__(*args, **kwargs)

        # Hidden states are needed to read out the first-token embedding.
        self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)

        self.config = self.model.config

        self.model_size = self.config.hidden_size

        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(RHYME_SCHEMES))  # Common Rhyme Type

        self.loss_fnc = torch.nn.CrossEntropyLoss(label_smoothing=0.05)

    def forward(self, input_ids=None, attention_mask=None, rhyme=None, *args, **kwargs):
        """Compute rhyme probabilities and the combined (classification + MLM) loss.

        Args:
            input_ids (_type_, optional): model inputs. Defaults to None.
            attention_mask (_type_, optional): attention mask where padding starts. Defaults to None.
            rhyme (_type_, optional): target rhyme distribution/one-hot. Defaults to None.

        Returns:
            dict: {"model_output": class probabilities, "loss": total loss}
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.type(torch.LongTensor))

        last_hidden = outputs['hidden_states'][-1]

        # First-token ([CLS]-like) embedding -> rhyme logits.
        rhyme_logits = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        softmaxed = torch.softmax(rhyme_logits, dim=1)

        # FIX: CrossEntropyLoss applies log_softmax internally and must be fed raw
        # logits. The original passed already-softmaxed probabilities (a double
        # softmax), which flattens the loss surface and weakens gradients.
        rhyme_loss = self.loss_fnc(rhyme_logits, rhyme)

        return {"model_output": softmaxed,
                "loss": rhyme_loss + outputs.loss}

    def predict(self, input_ids=None, *args, **kwargs):
        """Return rhyme-scheme probabilities for the given inputs.

        Args:
            input_ids (_type_, optional): model inputs. Defaults to None.

        Returns:
            torch.Tensor: softmaxed class probabilities
        """
        outputs = self.model(input_ids=input_ids)

        last_hidden = outputs['hidden_states'][-1]

        rhyme_logits = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))

        return torch.softmax(rhyme_logits, dim=1)

    def validate(self, input_ids=None, rhyme=None, k:int = 2, *args, **kwargs):
        """Validate one example against its gold rhyme label.

        Args:
            input_ids (_type_, optional): model inputs. Defaults to None.
            rhyme (_type_, optional): gold rhyme one-hot vector. Defaults to None.
            k (int, optional): top-k cutoff. Defaults to 2.

        Returns:
            dict: accuracy flag, top-k presence, Levenshtein distance of the
            predicted scheme string, and the probability given to the gold label

        NOTE(review): flattening the probabilities assumes batch size 1, and
        `.numpy()` assumes CPU tensors — confirm against callers.
        """
        probs = self.predict(input_ids=input_ids).flatten()

        predicted_val = torch.argmax(probs)

        predicted_top_k = torch.topk(probs, k).indices

        label_val = torch.argmax(rhyme.flatten())

        validation_true_val = (label_val == predicted_val).float().sum().numpy()
        top_k_presence = 1 if label_val in predicted_top_k else 0

        # FIX: `is not None` instead of `!= None` (identity check is the idiom).
        predicted_scheme = RHYME_SCHEMES[predicted_val] if RHYME_SCHEMES[predicted_val] is not None else ""
        gold_scheme = RHYME_SCHEMES[label_val] if RHYME_SCHEMES[label_val] is not None else ""
        levenshtein = jellyfish.levenshtein_distance(predicted_scheme, gold_scheme)

        hit_pred = probs[label_val].detach().numpy()

        return {"acc" : validation_true_val,
                "top_k" : top_k_presence,
                "lev_distance": levenshtein,
                "predicted_label" : hit_pred
                }
class MeterValidator(ValidatorInterface):
    """Validator that classifies a poem's meter type.

    A pretrained masked LM serves as the backbone; the first-token embedding of
    the last hidden layer feeds a linear head over METER_TYPES.
    """

    def __init__(self, pretrained_model, *args, **kwargs) -> None:
        """Load the backbone MLM and attach the meter classification head.

        Args:
            pretrained_model (str): name or path of the pretrained masked LM
        """
        super().__init__(*args, **kwargs)
        # Hidden states are needed to read out the first-token embedding.
        self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)

        self.config = self.model.config

        self.model_size = self.config.hidden_size

        self.meter_regressor = torch.nn.Linear(self.model_size, len(METER_TYPES))  # Meter Type

        self.loss_fnc = torch.nn.CrossEntropyLoss(label_smoothing=0.05)

    def forward(self, input_ids=None, attention_mask=None, metre=None, *args, **kwargs):
        """Compute meter probabilities and the combined (classification + MLM) loss.

        Args:
            input_ids (_type_, optional): model inputs. Defaults to None.
            attention_mask (_type_, optional): attention mask where padding starts. Defaults to None.
            metre (_type_, optional): target meter distribution/one-hot. Defaults to None.

        Returns:
            dict: {"model_output": class probabilities, "loss": total loss}
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.type(torch.LongTensor))

        last_hidden = outputs['hidden_states'][-1]

        # First-token ([CLS]-like) embedding -> meter logits.
        meter_logits = self.meter_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        softmaxed = torch.softmax(meter_logits, dim=1)

        # FIX: CrossEntropyLoss applies log_softmax internally and must be fed raw
        # logits. The original passed already-softmaxed probabilities (a double
        # softmax), which flattens the loss surface and weakens gradients.
        meter_loss = self.loss_fnc(meter_logits, metre)

        return {"model_output" : softmaxed,
                "loss": meter_loss + outputs.loss}

    def predict(self, input_ids=None, *args, **kwargs):
        """Return meter probabilities for the given inputs.

        Args:
            input_ids (_type_, optional): model inputs. Defaults to None.

        Returns:
            torch.Tensor: softmaxed class probabilities
        """
        outputs = self.model(input_ids=input_ids)

        last_hidden = outputs['hidden_states'][-1]

        meter_logits = self.meter_regressor(last_hidden[:, 0, :].view(-1, self.model_size))

        return torch.softmax(meter_logits, dim=1)

    def validate(self, input_ids=None, metre=None, k: int=2, *args, **kwargs):
        """Validate one example against its gold meter label.

        Args:
            input_ids (_type_, optional): model inputs. Defaults to None.
            metre (_type_, optional): gold meter one-hot vector. Defaults to None.
            k (int, optional): top-k cutoff. Defaults to 2.

        Returns:
            dict: accuracy flag, top-k presence, and the probability given to
            the gold label

        NOTE(review): flattening the probabilities assumes batch size 1, and
        `.numpy()` assumes CPU tensors — confirm against callers.
        """
        probs = self.predict(input_ids=input_ids).flatten()

        predicted_val = torch.argmax(probs)

        predicted_top_k = torch.topk(probs, k).indices

        label_val = torch.argmax(metre.flatten())

        validation_true_val = (label_val == predicted_val).float().sum().numpy()
        top_k_presence = 1 if label_val in predicted_top_k else 0

        hit_pred = probs[label_val].detach().numpy()

        return {"acc" : validation_true_val,
                "top_k" : top_k_presence,
                "predicted_label" : hit_pred
                }
| 196 |
+
|
class ValidatorTrainer:
    """Custom training loop for validator models using SAM.

    SAM (Sharpness-Aware Minimization) needs two forward/backward passes per
    batch, which the stock HF Trainer does not support — hence this loop.
    """

    def __init__(self, model: ValidatorInterface, args: dict, train_dataset: Dataset, data_collator, device):
        """Store model and hyperparameters and build the data loader/optimizer.

        Args:
            model (ValidatorInterface): validator to train
            args (dict): optional keys "epochs", "batch_size", "lr", "weight_decay"
            train_dataset (Dataset): training data
            data_collator: callable batching samples into model input dicts
            device: torch device that batches are moved to
        """
        self.model = model
        self.args = args
        # Hyperparameters, with defaults when absent from args (dict.get instead
        # of the `"key" not in args.keys()` dance).
        self.epochs = args.get("epochs", 1)
        self.batch_size = args.get("batch_size", 1)
        self.lr = args.get("lr", 3e-4)
        self.weight_decay = args.get("weight_decay", 0.0)

        self.train_loader = DataLoader(train_dataset, self.batch_size, True, collate_fn=data_collator)

        # SAM optimizer wrapping AdamW; constant LR after ~one epoch of warmup.
        self.device = device
        self.optimizer = SAM(self.model.parameters(), torch.optim.AdamW, lr=self.lr, weight_decay=self.weight_decay)
        self.scheduler = transformers.get_constant_schedule_with_warmup(self.optimizer, len(train_dataset)//self.batch_size)

    def train(self):
        """Train for self.epochs epochs with the two-pass SAM update.

        Prints the current loss every 100 steps.
        """
        for epoch in tqdm(range(self.epochs)):
            self.model.train()

            for step, batch in enumerate(self.train_loader):
                # Hoisted: move each tensor to the device once per batch (the
                # original did it twice, once per SAM pass).
                # FIX: `is None` instead of `== None` — equality on a possible
                # tensor relies on fragile fallback semantics.
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                rhyme = None if batch["rhyme"] is None else batch["rhyme"].to(self.device)
                metre = None if batch["metre"] is None else batch["metre"].to(self.device)

                # First SAM pass: perturb weights toward the local sharpness maximum.
                loss = self.model(input_ids=input_ids, attention_mask=attention_mask,
                                  rhyme=rhyme, metre=metre)['loss']
                loss.backward()
                self.optimizer.first_step(zero_grad=True)
                # Second SAM pass: gradient at the perturbed point does the real update.
                loss = self.model(input_ids=input_ids, attention_mask=attention_mask,
                                  rhyme=rhyme, metre=metre)['loss']
                loss.backward()
                self.optimizer.second_step(zero_grad=True)
                self.scheduler.step()

                # (A commented-out GSAM variant of this loop was removed; see VCS
                # history if that experiment needs to be revived.)
                if step % 100 == 0:
                    print(f'Step {step}, loss : {loss.item()}', flush=True)