shomez commited on
Commit
2613acc
·
verified ·
1 Parent(s): 9a73983

Upload codec.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. codec.py +118 -0
codec.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .core import encode, decode, alabel, ulabel, IDNAError
2
+ import codecs
3
+ import re
4
+ from typing import Any, Tuple, Optional
5
+
6
+ _unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')
7
+
8
+ class Codec(codecs.Codec):
9
+
10
+ def encode(self, data: str, errors: str = 'strict') -> Tuple[bytes, int]:
11
+ if errors != 'strict':
12
+ raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
13
+
14
+ if not data:
15
+ return b"", 0
16
+
17
+ return encode(data), len(data)
18
+
19
+ def decode(self, data: bytes, errors: str = 'strict') -> Tuple[str, int]:
20
+ if errors != 'strict':
21
+ raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
22
+
23
+ if not data:
24
+ return '', 0
25
+
26
+ return decode(data), len(data)
27
+
28
+ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
29
+ def _buffer_encode(self, data: str, errors: str, final: bool) -> Tuple[bytes, int]:
30
+ if errors != 'strict':
31
+ raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
32
+
33
+ if not data:
34
+ return b'', 0
35
+
36
+ labels = _unicode_dots_re.split(data)
37
+ trailing_dot = b''
38
+ if labels:
39
+ if not labels[-1]:
40
+ trailing_dot = b'.'
41
+ del labels[-1]
42
+ elif not final:
43
+ # Keep potentially unfinished label until the next call
44
+ del labels[-1]
45
+ if labels:
46
+ trailing_dot = b'.'
47
+
48
+ result = []
49
+ size = 0
50
+ for label in labels:
51
+ result.append(alabel(label))
52
+ if size:
53
+ size += 1
54
+ size += len(label)
55
+
56
+ # Join with U+002E
57
+ result_bytes = b'.'.join(result) + trailing_dot
58
+ size += len(trailing_dot)
59
+ return result_bytes, size
60
+
61
+ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
62
+ def _buffer_decode(self, data: Any, errors: str, final: bool) -> Tuple[str, int]:
63
+ if errors != 'strict':
64
+ raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
65
+
66
+ if not data:
67
+ return ('', 0)
68
+
69
+ if not isinstance(data, str):
70
+ data = str(data, 'ascii')
71
+
72
+ labels = _unicode_dots_re.split(data)
73
+ trailing_dot = ''
74
+ if labels:
75
+ if not labels[-1]:
76
+ trailing_dot = '.'
77
+ del labels[-1]
78
+ elif not final:
79
+ # Keep potentially unfinished label until the next call
80
+ del labels[-1]
81
+ if labels:
82
+ trailing_dot = '.'
83
+
84
+ result = []
85
+ size = 0
86
+ for label in labels:
87
+ result.append(ulabel(label))
88
+ if size:
89
+ size += 1
90
+ size += len(label)
91
+
92
+ result_str = '.'.join(result) + trailing_dot
93
+ size += len(trailing_dot)
94
+ return (result_str, size)
95
+
96
+
97
+ class StreamWriter(Codec, codecs.StreamWriter):
98
+ pass
99
+
100
+
101
+ class StreamReader(Codec, codecs.StreamReader):
102
+ pass
103
+
104
+
105
+ def search_function(name: str) -> Optional[codecs.CodecInfo]:
106
+ if name != 'idna2008':
107
+ return None
108
+ return codecs.CodecInfo(
109
+ name=name,
110
+ encode=Codec().encode,
111
+ decode=Codec().decode,
112
+ incrementalencoder=IncrementalEncoder,
113
+ incrementaldecoder=IncrementalDecoder,
114
+ streamwriter=StreamWriter,
115
+ streamreader=StreamReader,
116
+ )
117
+
118
+ codecs.register(search_function)