ag14850 committed on
Commit
82da8fb
·
verified ·
1 Parent(s): ef716ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -2
app.py CHANGED
@@ -1,7 +1,106 @@
1
  import gradio as gr
2
- from transformers import T5ForConditionalGeneration, AutoTokenizer
 
 
 
 
 
 
3
 
4
- model = T5ForConditionalGeneration.from_pretrained("ag14850/Mosquito")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-base", legacy=False)
6
 
7
  def ask(question):
 
1
  import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import struct
5
+ import lzma
6
+ import json
7
+ from huggingface_hub import hf_hub_download
8
+ from transformers import T5Config, T5ForConditionalGeneration, AutoTokenizer
9
 
10
# Fetch the xz-compressed quantized checkpoint file from the Hub repository.
model_path = hf_hub_download(
    repo_id="ag14850/Mosquito",
    filename="mosquito_tiny.bin.xz",
)
12
+
13
def unpack_nbits(data, bits, count):
    """Unpack densely packed unsigned integers of width ``bits`` from ``data``.

    Supported layouts:

    * 8 bits: one value per byte.
    * 4 bits: two values per byte, high nibble first.
    * 6 bits: four values per 3-byte group, most significant bits first.
    * 5 / 7 bits: eight values per ``bits``-byte group; the group is read as
      a little-endian integer and the first value sits in its lowest bits.

    An incomplete trailing group (6-, 5- and 7-bit layouts) is ignored.

    Args:
        data: Bytes-like object holding the packed values.
        bits: Width of each packed value; one of 4, 5, 6, 7 or 8.
        count: Maximum number of values to return; surplus padding values
            at the end of ``data`` are dropped.

    Returns:
        A 1-D ``np.uint8`` array with at most ``count`` values. It is shorter
        than ``count`` when ``data`` does not contain enough complete groups.

    Raises:
        ValueError: If ``bits`` is not a supported width.
    """
    raw = np.frombuffer(data, dtype=np.uint8)

    if bits == 8:
        return raw[:count]

    if bits == 4:
        values = np.empty(raw.size * 2, dtype=np.uint8)
        values[0::2] = raw >> 4
        values[1::2] = raw & 0x0F
        return values[:count]

    if bits == 6:
        # Only whole 3-byte groups are decoded.
        groups = raw[: raw.size // 3 * 3].reshape(-1, 3)
        b0, b1, b2 = groups[:, 0], groups[:, 1], groups[:, 2]
        values = np.empty((groups.shape[0], 4), dtype=np.uint8)
        values[:, 0] = b0 >> 2
        values[:, 1] = ((b0 & 0x03) << 4) | (b1 >> 4)
        values[:, 2] = ((b1 & 0x0F) << 2) | (b2 >> 6)
        values[:, 3] = b2 & 0x3F
        return values.reshape(-1)[:count]

    if bits in (5, 7):
        # Only whole ``bits``-byte groups are decoded; each holds 8 values.
        groups = raw[: raw.size // bits * bits].reshape(-1, bits)
        # Assemble every group into one little-endian integer (<= 56 bits).
        packed = np.zeros(groups.shape[0], dtype=np.uint64)
        for byte_index in range(bits):
            packed |= groups[:, byte_index].astype(np.uint64) << np.uint64(8 * byte_index)
        shifts = np.arange(8, dtype=np.uint64) * np.uint64(bits)
        mask = np.uint64((1 << bits) - 1)
        values = ((packed[:, None] >> shifts) & mask).astype(np.uint8)
        return values.reshape(-1)[:count]

    # Previously an unknown width silently produced an empty array.
    raise ValueError(f"unsupported bit width: {bits}")
45
+
46
def load_quantized_model(path):
    """Load a T5 model from an xz-compressed quantized checkpoint.

    Layout as parsed here (all integers little-endian):

    * Header ``<BBH``: version, default bit width, number of tensors.
      Only the tensor count is used.
    * Per tensor: ``<H`` name length, UTF-8 name, ``<B`` rank, rank x ``<I``
      dimensions, ``<B`` bit width, then either
      - bit width < 16: ``<ff`` scale and zero point, ``<I`` payload length,
        packed payload; dequantized as ``(q - zero_point) * scale``; or
      - otherwise: ``<I`` payload length followed by raw float16 values.
    * Trailer: ``<I`` length followed by the model config as UTF-8 JSON.

    Args:
        path: Path of the ``.xz`` checkpoint file.

    Returns:
        A ``T5ForConditionalGeneration`` with float32 weights loaded, in
        eval mode.

    Raises:
        ValueError: If a quantized tensor unpacks to a different number of
            values than its shape requires.
    """
    with lzma.open(path, 'rb') as f:
        blob = f.read()

    # A memoryview lets payload slices be taken without copying the blob.
    view = memoryview(blob)
    offset = 0

    def read(fmt):
        """Unpack ``fmt`` at the cursor and advance past it."""
        nonlocal offset
        fields = struct.unpack_from(fmt, view, offset)
        offset += struct.calcsize(fmt)
        return fields

    def take(length):
        """Return the next ``length`` bytes as a view and advance past them."""
        nonlocal offset
        chunk = view[offset:offset + length]
        offset += length
        return chunk

    _version, _default_bits, num_params = read('<BBH')

    state_dict = {}

    for _ in range(num_params):
        (name_len,) = read('<H')
        name = bytes(take(name_len)).decode('utf-8')

        (ndim,) = read('<B')
        shape = read(f'<{ndim}I')
        numel = int(np.prod(shape)) if shape else 1

        (bits,) = read('<B')

        if bits < 16:
            scale, zp = read('<ff')
            (packed_len,) = read('<I')
            quantized = unpack_nbits(take(packed_len), bits, numel)
            if quantized.size != numel:
                # Fail with the tensor name rather than an anonymous reshape error.
                raise ValueError(
                    f"tensor {name!r}: expected {numel} values, "
                    f"unpacked {quantized.size}"
                )
            values = (quantized.astype(np.float32) - zp) * scale
        else:
            (fp16_len,) = read('<I')
            values = np.frombuffer(take(fp16_len), dtype=np.float16).astype(np.float32)

        state_dict[name] = torch.from_numpy(values.reshape(shape))

    (config_len,) = read('<I')
    config_json = bytes(take(config_len)).decode('utf-8')

    config = T5Config.from_dict(json.loads(config_json))
    model = T5ForConditionalGeneration(config)
    model.load_state_dict(state_dict)
    model.eval()

    return model
101
+
102
# Build the model from the downloaded checkpoint, then load the tokenizer.
model = load_quantized_model(model_path)
tokenizer = AutoTokenizer.from_pretrained(
    "google/t5-v1_1-base",
    legacy=False,
)
105
 
106
  def ask(question):