Spaces:
Running
Running
File size: 2,263 Bytes
11e4edf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | # -*- coding: utf-8 -*-
"""
Correctif de compatibilité fastcoref ↔ transformers >= 5.x.
fastcoref 2.1.6 a été écrit pour transformers 4.x. Sous transformers 5.x,
le chargement de LingMess/FCoref échoue pour deux raisons :
1. Longformer ne supporte pas l'attention SDPA → il faut forcer
`attn_implementation="eager"` à la construction du modèle de base.
2. transformers 5.x attend un attribut de classe `all_tied_weights_keys`
que fastcoref ne définit pas.
Ce script applique les deux correctifs DIRECTEMENT dans les fichiers
installés de fastcoref. Il est idempotent : relançable sans risque.
Usage :
python patch_fastcoref.py
"""
import os
import re
import fastcoref
CIBLES = ["coref_models/modeling_lingmess.py",
"coref_models/modeling_fcoref.py"]
CLASSES = {
"modeling_lingmess.py": "LingMessModel",
"modeling_fcoref.py": "FCorefModel",
}
def patch_fichier(chemin):
nom = os.path.basename(chemin)
with open(chemin, "r", encoding="utf-8") as f:
src = f.read()
original = src
# 1) forcer attn_implementation="eager"
if 'AutoModel.from_config(config, attn_implementation="eager")' not in src:
src = src.replace(
"AutoModel.from_config(config)",
'AutoModel.from_config(config, attn_implementation="eager")',
)
# 2) ajouter l'attribut de classe all_tied_weights_keys = {}
classe = CLASSES[nom]
if "all_tied_weights_keys" not in src:
src = re.sub(
rf"(class {classe}\([^)]*\):\n)(\s+def __init__)",
rf"\1 all_tied_weights_keys = {{}} # compat transformers>=5.x\n\n\2",
src,
count=1,
)
if src != original:
with open(chemin, "w", encoding="utf-8") as f:
f.write(src)
return "patché"
return "déjà OK"
def main():
base = os.path.dirname(fastcoref.__file__)
print(f"fastcoref trouvé : {base}")
for rel in CIBLES:
chemin = os.path.join(base, rel)
if not os.path.exists(chemin):
print(f" ⚠ introuvable : {rel}")
continue
etat = patch_fichier(chemin)
print(f" {rel} : {etat}")
print("Correctif terminé.")
if __name__ == "__main__":
main()
|