|
|
import torch |
|
|
import torch.nn as nn |
|
|
from tokenizers import Tokenizer |
|
|
import re |
|
|
import argparse |
|
|
import sys |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StabilizedDenoisingModel(nn.Module):
    """Feed-forward token model with gated residual "denoising" blocks.

    Tokens are embedded, projected into a hidden space, then passed through
    ``num_layers`` gated residual blocks before a final vocabulary projection.
    There is no attention or recurrence: every position is processed
    independently.

    Args:
        vocab_size: size of the token vocabulary (embedding rows / logit dim).
        embed_dim: token embedding width.
        hidden_dim: width of the hidden space all blocks operate in.
        num_layers: number of gated residual "denoise" blocks.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()
        # NOTE: module creation order is kept exactly as in the checkpoint's
        # training code so parameter init under a fixed seed is reproducible
        # and state_dict keys line up.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.row_transform = nn.Linear(embed_dim, hidden_dim)
        self.dim_transform = nn.Linear(hidden_dim, hidden_dim)
        # A single LayerNorm instance, shared between the input projection
        # and every residual merge in forward().
        self.norm = nn.LayerNorm(hidden_dim)

        self.denoise_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.denoise_layers.append(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, hidden_dim),
                )
            )

        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_seq):
        """Map token ids to per-position vocabulary logits.

        Args:
            input_seq: long tensor of token ids, shape ``(batch, seq_len)``.

        Returns:
            Float tensor of logits, shape ``(batch, seq_len, vocab_size)``.
        """
        hidden = self.embedding(input_seq)
        hidden = self.row_transform(hidden)
        hidden = self.norm(self.dim_transform(hidden))

        for block in self.denoise_layers:
            raw = block(hidden)
            gate = torch.sigmoid(raw)
            # Gated mix: subtract the gated signal, add back its rectified
            # complement, then merge through the shared residual norm.
            cleaned = hidden - gate * raw + (1 - gate) * torch.relu(raw)
            hidden = self.norm(hidden + cleaned)

        return self.output_layer(hidden)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_text(text):
    """Normalize raw user input for tokenization.

    Lower-cases the text, strips every character outside lowercase letters,
    digits, whitespace and basic punctuation, then collapses all whitespace
    runs to single spaces and trims the ends.
    """
    lowered = text.lower()
    filtered = re.sub(r'[^a-z0-9\s.,!?;:\'"-]', '', lowered)
    return re.sub(r'\s+', ' ', filtered).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def stream_generate_text(model, tokenizer, device, start_text, max_len=100, temperature=0.8):
    """流式生成文本,逐个token输出(修复输出问题)

    Greedily samples up to ``max_len`` tokens from ``model``, printing each
    newly decoded text fragment as it is produced, and returns the full
    decoded text (prompt included).

    Args:
        model: callable nn.Module mapping (1, seq) long ids to
            (1, seq, vocab) logits.
        tokenizer: tokenizer providing ``encode(...).ids``, ``decode(ids)``
            and ``token_to_id(token)``.
        device: torch device the model lives on.
        start_text: raw prompt; cleaned via ``clean_text`` before encoding.
        max_len: maximum number of tokens to generate.
        temperature: softmax temperature applied to the logits.

    Returns:
        The decoded text for the prompt plus all generated tokens.
    """
    model.eval()

    start_text = clean_text(start_text)

    # NOTE(review): assumes the cleaned prompt encodes to at least one token;
    # an empty prompt yields a (1, 0) input — verify the model handles it.
    input_ids = tokenizer.encode(start_text).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)

    generated_ids = input_ids.copy()

    print(start_text, end="", flush=True)

    # FIX: the streaming baseline must be the length of the *decoded* prompt,
    # not of start_text — decode() may not round-trip to the cleaned string,
    # which previously dropped or duplicated characters in streamed output.
    last_output_length = len(tokenizer.decode(generated_ids))

    # Hoist the stop-token lookup out of the sampling loop.
    sep_id = tokenizer.token_to_id("<SEP>")

    for _ in range(max_len):
        with torch.no_grad():
            # Sliding context window: keep only the most recent 100 tokens.
            if input_tensor.size(1) > 100:
                input_tensor = input_tensor[:, -100:]

            logits = model(input_tensor)
            next_token_logits = logits[:, -1, :] / temperature
            probs = torch.softmax(next_token_logits, dim=-1)

            # Prune low-probability tail tokens, but FIX: fall back to the
            # unpruned distribution when nothing survives the threshold —
            # the old code divided by a zero sum, producing NaNs that made
            # torch.multinomial raise.
            pruned = probs.masked_fill(probs < 0.01, 0.0)
            total = pruned.sum()
            if total.item() > 0:
                probs = pruned / total

        next_token = torch.multinomial(probs, num_samples=1).item()

        # Stop on the separator token (not appended to the output).
        if next_token == sep_id:
            break

        generated_ids.append(next_token)
        next_token_tensor = torch.tensor([[next_token]], device=device, dtype=torch.long)
        input_tensor = torch.cat([input_tensor, next_token_tensor], dim=1)

        # Re-decode the whole sequence and emit only the new suffix, since
        # BPE merges can change earlier characters' spelling boundaries.
        current_text = tokenizer.decode(generated_ids)
        new_text = current_text[last_output_length:]
        last_output_length = len(current_text)

        print(new_text, end="", flush=True)

    return tokenizer.decode(generated_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(model_path, tokenizer_path):
    """Load tokenizer and checkpoint, then run an interactive generation REPL.

    Args:
        model_path: path to the saved model weights (raw state_dict or a
            checkpoint dict with a ``model_state_dict`` key).
        tokenizer_path: path to the serialized tokenizer JSON.
    """
    # Pick the best available accelerator, falling back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"使用设备: {device}")

    tokenizer = Tokenizer.from_file(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()
    print(f"加载分词器成功,词汇表大小: {vocab_size}")

    # Architecture hyperparameters must match the saved checkpoint.
    model = StabilizedDenoisingModel(
        vocab_size=vocab_size,
        embed_dim=256,
        hidden_dim=512,
        num_layers=16,
    ).to(device)

    try:
        checkpoint = torch.load(model_path, map_location=device)
        # Accept either a full checkpoint dict or a bare state_dict.
        state = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        model.load_state_dict(state)
        print(f"加载模型成功: {model_path}")
    except Exception as e:
        print(f"模型加载失败: {str(e)}")
        return

    print("\n===== GTC-2 Large Base Model Text Generator (Early Research Preview) =====")
    print("输入文本后按回车生成,输入'quit'退出")

    while True:
        user_input = input("\n输入: ")

        # Ignore accidentally pasted shell commands activating a virtualenv.
        if "activate" in user_input and "venv" in user_input:
            print("检测到虚拟环境激活命令,已忽略")
            continue

        if user_input.lower() == 'quit':
            break

        sys.stdout.flush()
        print("生成: ", end="", flush=True)
        stream_generate_text(
            model,
            tokenizer,
            device,
            user_input,
            max_len=100,
            temperature=0.8,
        )
        print("\n")
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: resolve file paths from arguments, then start the REPL.
    cli = argparse.ArgumentParser(description='GTC-2 Large Base Model 文本生成器')
    cli.add_argument('--model', type=str, default='best_model.pth',
                     help='模型文件路径 (默认: best_model.pth)')
    cli.add_argument('--tokenizer', type=str, default='bpe_tokenizer.json',
                     help='分词器文件路径 (默认: bpe_tokenizer.json)')
    opts = cli.parse_args()
    main(opts.model, opts.tokenizer)