|
|
import os
import sys
import re
import json
import math
import copy
import time
import uuid
import queue
import threading
from functools import lru_cache

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from tqdm import trange
import wget
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from deep_translator import GoogleTranslator
from flask import Flask, request, jsonify
|
|
|
|
|
# Pretrained small GPT-2 checkpoint and BPE assets (graykode/gpt-2-Pytorch mirror).
MODEL_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"
ENCODER_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/encoder.json"
VOCAB_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/vocab.bpe"
GPT2_FOLDER = "./GPT2"
MODEL_FILE = "gpt2-pytorch_model.bin"
ENCODER_FILE = "encoder.json"
VOCAB_FILE = "vocab.bpe"
# Number of synthetic training texts produced per pass of generate_and_queue_text().
TEXT_GENERATION_RATE = 40000
|
|
html_code = """<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><title>AI Text Generation Platform</title><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" integrity="sha512-9usAa10IRO0HhonpyAIVpjrylPvoDwiPUiKdWk5t3PyolY1cOd4DSE0Ga+ri4AuTroPR5aQvXU9xC6qOPnzFeg==" crossorigin="anonymous" referrerpolicy="no-referrer" /><style>body {font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;background: linear-gradient(to right, #667eea, #764ba2);color: #fff;margin: 0;padding: 0;overflow: hidden;display: flex;justify-content: center;align-items: center;min-height: 100vh;}.container {width: 80%;max-width: 900px;margin: 50px auto;background: rgba(255, 255, 255, 0.1);padding: 40px;border-radius: 15px;box-shadow: 0 5px 25px rgba(0, 0, 0, 0.2);backdrop-filter: blur(10px);border: 1px solid rgba(255, 255, 255, 0.2);transition: transform 0.3s ease-in-out;}.container:hover {transform: translateY(-5px);}.header {text-align: center;margin-bottom: 40px;color: #fff;}.header h1 {font-size: 3.5em;font-weight: 600;letter-spacing: 1px;text-shadow: 0 2px 5px rgba(0, 0, 0, 0.3);}.header p {font-size: 1.2em;opacity: 0.8;text-shadow: 0 1px 3px rgba(0, 0, 0, 0.3);}.form-group {margin-bottom: 30px;}.form-group label {display: block;margin-bottom: 10px;font-size: 1.1em;font-weight: 500;opacity: 0.9;text-shadow: 0 1px 2px rgba(0, 0, 0, 0.3);}.form-group textarea, .form-group input[type="text"] {width: 100%;padding: 15px;border: none;border-radius: 8px;background: rgba(255, 255, 255, 0.2);color: #fff;font-size: 16px;box-shadow: inset 0 2px 8px rgba(0, 0, 0, 0.2);backdrop-filter: blur(8px);transition: all 0.3s ease;}.form-group textarea:focus, .form-group input[type="text"]:focus {outline: none;background: rgba(255, 255, 255, 0.3);box-shadow: inset 0 3px 12px rgba(0, 0, 0, 0.3);}button {width: 100%;padding: 15px;border: none;border-radius: 8px;background: linear-gradient(to right, #43cea2, #185a9d);color: #fff;font-size: 18px;font-weight: 500;cursor: pointer;box-shadow: 0 3px 10px rgba(0, 0, 0, 0.3);transition: background 0.3s ease;position: relative;overflow: hidden;}button:hover {background: linear-gradient(to left, #43cea2, #185a9d);}button::before {content: '\\f021';font-family: FontAwesome;position: absolute;top: 0;left: -30px;width: 30px;height: 100%;background: rgba(255, 255, 255, 0.3);display: flex;align-items: center;justify-content: center;transition: left 0.4s ease;}button:hover::before {left: 100%;transform: translateX(-100%);}.animated-text {position: absolute;top: 50%;left: 50%;transform: translate(-50%, -50%);font-size: 6em;font-weight: bold;color: rgba(255, 255, 255, 0.05);pointer-events: none;z-index: -1;}#output {margin-top: 40px;padding: 25px;border-radius: 10px;background: rgba(255, 255, 255, 0.1);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.2);backdrop-filter: blur(10px);border: 1px solid rgba(255, 255, 255, 0.2);white-space: pre-wrap;word-break: break-word;opacity: 0;transition: opacity 1s ease, transform 0.5s ease;transform: translateY(20px);}#output.show {opacity: 1;transform: translateY(0);}#output strong {color: #93b8c2;font-weight: 600;}@keyframes fadeIn {from {opacity: 0;transform: translateY(-20px);}to {opacity: 1;transform: translateY(0);}}</style></head><body><div class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI 
POWERED</div><div class="container"><div class="header animate__animated animate__fadeInDown"><h1>AI Text Generation Platform</h1><p>Unleash your creativity with our advanced text generation tool.</p></div><div class="form-group animate__animated animate__fadeInLeft"><label for="text">Input Text:</label><textarea id="text" rows="5" placeholder="Enter your text here"></textarea></div><div class="form-group animate__animated animate__fadeInRight"><label for="length">Length:</label><input type="text" id="length" value="50" placeholder="Enter desired length"></div><div class="form-group animate__animated animate__fadeInLeft"><label for="temperature">Temperature:</label><input type="text" id="temperature" value="0.7" placeholder="Enter temperature (0-1)"></div><div class="form-group animate__animated animate__fadeInRight"><label for="top_k">Top K:</label><input type="text" id="top_k" value="40" placeholder="Enter top K value"></div><button onclick="generateText()" class="animate__animated animate__fadeInUp">Generate Text</button><div id="output" class="animate__animated"><strong>Generated Text:</strong><br><span id="generatedText"></span><br><br><strong>Reasoning:</strong><br><span id="reasoning"></span><br><br><strong>Category:</strong><br><span id="category"></span></div></div><script>async function generateText() {const inputText = document.getElementById("text").value;const length = parseInt(document.getElementById("length").value);const temperature = parseFloat(document.getElementById("temperature").value);const top_k = parseInt(document.getElementById("top_k").value);const data = {text: inputText,length: length,temperature: temperature,top_k: top_k};const response = await fetch('/generate', {method: 'POST',headers: {'Content-Type': 'application/json'},body: JSON.stringify(data)});const result = await response.json();document.getElementById("generatedText").innerText = result.generated_text;document.getElementById("reasoning").innerText = result.reasoning;document.getElementById("category").innerText = result.category;const outputDiv = document.getElementById("output");outputDiv.classList.add("show");applyFeedback(inputText, result.category)}async function applyFeedback(originalText, predictedCategory) {const correctCategory = prompt("Is category correct? Enter correct category or leave blank:");if (correctCategory) {const feedbackData = {text: originalText,category: correctCategory};await fetch('/generate', {method: 'POST',headers: {'Content-Type': 'application/json'},body: JSON.stringify(feedbackData)})}}</script></body></html>"""
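
# The page above drives the API below: it POSTs JSON to /generate and renders
# the "generated_text", "reasoning", and "category" fields of the response.
# A minimal request from the command line (assuming the default port set in
# run_app() at the bottom of this file):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Once upon a time", "length": 50, "temperature": 0.7, "top_k": 40}'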
|
|
|
|
|
class GPT2Config:
    # Defaults match the small GPT-2 checkpoint downloaded above
    # (50257-token vocab, 1024 positions, 768-dim embeddings, 12 layers, 12 heads).
    def __init__(self, vocab_size_or_config_json_file=50257, n_positions=1024, n_ctx=1024,
                 n_embd=768, n_layer=12, n_head=12, layer_norm_epsilon=1e-5, initializer_range=0.02):
        self.vocab_size = vocab_size_or_config_json_file
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
|
|
|
|
|
@lru_cache()
def bytes_to_unicode():
    # Build a reversible map from the 256 byte values to printable unicode
    # characters, so arbitrary UTF-8 byte sequences can be fed through the BPE.
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
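
# Example: printable bytes map to themselves (ord('!') == 33 -> '!'), while the
# space byte (32) maps to chr(256 + 32) == 'Ġ', which is why GPT-2 BPE tokens
# render leading spaces as 'Ġ'.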
|
|
|
|
|
def get_pairs(word):
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
|
|
|
|
|
class Encoder:
    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors when decoding UTF-8
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        # Tokenization pattern; uses \w in place of the \p{L}/\p{N} classes of
        # OpenAI's original pattern, since the stdlib `re` module lacks them.
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE)
|
|
|
|
|
    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)
        if not pairs:
            return token
        while True:
            # Apply the lowest-ranked (earliest-learned) merge first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word
|
|
|
|
|
    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text
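
# Round-trip sketch (exact token ids depend on the downloaded encoder.json; the
# values shown are from the standard GPT-2 vocabulary and are illustrative):
#   enc = get_encoder()
#   enc.encode("Hello world")   ->  [15496, 995]
#   enc.decode([15496, 995])    ->  "Hello world"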
|
|
|
|
|
def get_encoder():
    encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
    vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
    with open(encoder_path, 'r', encoding="utf-8") as f:
        encoder = json.load(f)
    with open(vocab_path, 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    # The first line of vocab.bpe is a version header; the last line is empty.
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(encoder=encoder, bpe_merges=bpe_merges)
|
|
|
|
|
def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
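
# This is the tanh approximation of GELU(x) = x * Φ(x) used by the original
# GPT-2 release; in recent PyTorch versions F.gelu(x, approximate='tanh')
# computes the same quantity.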
|
|
|
|
|
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
|
|
|
|
|
class Conv1D(nn.Module):
    # A linear layer with a (nx, nf) weight layout, matching the TensorFlow
    # checkpoint convention used by the pretrained GPT-2 weights.
    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x
|
|
|
|
|
class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        n_state = nx
        assert n_state % config.n_head == 0
        # Lower-triangular causal mask, broadcastable over (batch, head).
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

    def _attn(self, q, k, v):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        b = self.bias[:, :, ns - nd:ns, :ns]
        # Push masked (future) positions to ~-1e10 before the softmax.
        w = w * b - 1e10 * (1 - b)
        w = nn.Softmax(dim=-1)(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)
        if k:
            # Keys are kept pre-transposed as (batch, head, head_dim, seq).
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, layer_past=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            # Prepend cached keys/values from earlier decoding steps.
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present
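
# Incremental decoding: sample_sequence() below feeds only the newest token at
# each step and threads `past` (the stacked key/value `present` tensors) back
# in, so each step attends over the cache instead of recomputing the prefix.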
|
|
|
|
|
class MLP(nn.Module):
    def __init__(self, n_state, config):
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return h2
|
|
|
|
|
class Block(nn.Module):
    # Pre-LayerNorm transformer block: LN -> attention -> residual,
    # then LN -> MLP -> residual.
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None):
        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present
|
|
|
|
|
class GPT2Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.vocab_size
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)   # token embeddings
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)  # position embeddings
        block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def set_embeddings_weights(self, model_embeddings_weights):
        # Unused here; the actual weight tying happens in GPT2LMHead via
        # GPT2LMHeadModel.set_tied().
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length,
                                        dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        presents = []
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, layer_past)
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents
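
# Shape walkthrough (batch B, sequence T): input_ids (B, T) -> summed token and
# position embeddings (B, T, n_embd) -> each Block preserves that shape ->
# final hidden states (B, T, n_embd). `presents` carries one stacked key/value
# tensor per layer for reuse as `past` on the next decoding step.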
|
|
|
|
|
class GPT2LMHead(nn.Module):
    def __init__(self, model_embeddings_weights, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.set_embeddings_weights(model_embeddings_weights)

    def set_embeddings_weights(self, model_embeddings_weights):
        # Tie the output projection to the token embedding matrix.
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights

    def forward(self, hidden_state):
        lm_logits = self.decoder(hidden_state)
        return lm_logits
|
|
|
|
|
class GPT2LMHeadModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)

    def set_tied(self):
        # Re-tie the LM head after the embedding weights have been (re)loaded.
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)
        if lm_labels is not None:
            # With labels, the forward pass returns only the scalar LM loss.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
            return loss
        return lm_logits, presents
|
|
|
|
|
def top_k_logits(logits, k):
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1]
    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
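
# Example: with k=1 every logit below the per-row maximum is pushed to -1e10,
# so subsequent sampling is effectively greedy; with k=0 the logits pass
# through unchanged (top-k filtering disabled).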
|
|
|
|
|
def sample_sequence(model, length, start_token=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((1, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for _ in trange(length):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(probs, num_samples=1)
            else:
                _, prev = torch.topk(probs, k=1, dim=-1)
            output = torch.cat((output, prev), dim=1)
    return output
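
# Usage sketch (assumes initialize_model() below has populated `enc`, `model`,
# and `device`):
#   tokens = enc.encode("The meaning of life is")
#   out = sample_sequence(model, length=20, context=tokens, temperature=0.8,
#                         top_k=40, device=device)
#   print(enc.decode(out[0, len(tokens):].tolist()))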
|
|
|
|
|
def load_weight(model, state_dict):
    # Rename parameters from TensorFlow-converted checkpoints
    # (.g -> .weight, .b -> .bias, .w -> .weight).
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if key.endswith(".g"):
            new_key = key[:-2] + ".weight"
        elif key.endswith(".b"):
            new_key = key[:-2] + ".bias"
        elif key.endswith(".w"):
            new_key = key[:-2] + ".weight"
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    # Checkpoints may or may not carry the 'transformer.' prefix.
    start_model = model
    if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
        start_model = model.transformer
    load(start_model, prefix="")
    model.set_tied()
    return model
|
|
|
|
|
def download_file(url, filename, retries=3):
    for attempt in range(retries):
        try:
            wget.download(url, out=filename)
            return True
        except Exception as e:
            print(f"Download error on attempt {attempt + 1}: {e}")
            if attempt < retries - 1:
                time.sleep(2)
            else:
                print(f"Failed to download {url} after {retries} attempts.")
                return False
    return False
|
|
|
|
|
|
|
|
def ensure_files_exist():
    global MODEL_URL, ENCODER_URL, VOCAB_URL
    if not os.path.exists(GPT2_FOLDER):
        os.makedirs(GPT2_FOLDER)
    model_path = os.path.join(GPT2_FOLDER, MODEL_FILE)
    encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
    vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
    if not os.path.exists(model_path):
        print(f"Downloading model from {MODEL_URL} to {model_path}")
        if not download_file(MODEL_URL, model_path):
            print("Exiting due to model download failure.")
            sys.exit(1)
    if not os.path.exists(encoder_path):
        print(f"Downloading encoder from {ENCODER_URL} to {encoder_path}")
        if not download_file(ENCODER_URL, encoder_path):
            print("Exiting due to encoder download failure.")
            sys.exit(1)
    if not os.path.exists(vocab_path):
        print(f"Downloading vocab from {VOCAB_URL} to {vocab_path}")
        if not download_file(VOCAB_URL, vocab_path):
            print("Exiting due to vocab download failure.")
            sys.exit(1)
|
|
|
|
|
def translate_text(text, target_language='es', delay=1):
    try:
        translator = GoogleTranslator(source='auto', target=target_language)
        translated_text = translator.translate(text)
        time.sleep(delay)  # throttle requests to the translation service
        return translated_text
    except Exception as e:
        print(f"Translation Error: {e}")
        return text
|
|
|
|
|
# Shared state for the model, the classifier, and the background workers.
state_dict = None
enc = None
config = None
model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
news_clf = None
tfidf_vectorizer = None
text_queue = queue.Queue()
categories = None
is_training = False
background_threads = []
feedback_queue = queue.Queue()
|
|
|
|
|
def initialize_model():
    global state_dict, enc, config, model, device, GPT2_FOLDER, MODEL_FILE
    if state_dict is None:
        ensure_files_exist()
        model_path = os.path.join(GPT2_FOLDER, MODEL_FILE)
        state_dict = torch.load(model_path, map_location=device)
        enc = get_encoder()
        config = GPT2Config()
        model = GPT2LMHeadModel(config).to(device)
        model = load_weight(model, state_dict)
        model.eval()
|
|
|
|
|
def perform_reasoning(text):
    if model is None:
        initialize_model()
    try:
        reasoning_prompt = f"Given: '{text}', what's inferred?"
        context_tokens = enc.encode(reasoning_prompt)
        out = sample_sequence(model=model, length=100, context=context_tokens,
                              temperature=0.7, top_k=40, device=device)
        out = out[:, len(context_tokens):].tolist()  # keep only the continuation
        return enc.decode(out[0])
    except Exception as e:
        print(f"Reasoning Error: {e}")
        return ""
|
|
|
|
|
def get_category(text):
    global news_clf, categories, tfidf_vectorizer
    if news_clf is None or tfidf_vectorizer is None or categories is None:
        return "Not initialized"
    try:
        tfidf_matrix = tfidf_vectorizer.transform([text])
        predicted_category_index = news_clf.predict(tfidf_matrix)[0]
        if 0 <= predicted_category_index < len(categories):
            return categories[predicted_category_index]
        else:
            print(f"Error: Predicted category index {predicted_category_index} out of bounds "
                  f"(categories size: {len(categories)}). Returning 'Unknown Category'.")
            return "Unknown Category"
    except IndexError as e:
        print(f"IndexError in get_category: {e}. Categories: {categories}. Returning 'Error Category'.")
        return "Error Category"
    except Exception as e:
        print(f"Category Prediction Error: {e}")
        return "Error"
|
|
|
|
|
|
|
|
def generate_and_queue_text(language):
    # Endlessly produce labeled synthetic texts for the background trainer.
    global text_queue, categories
    if categories is None:
        print("Categories not initialized.")
        return
    num_categories = len(categories)
    num_texts_per_category = TEXT_GENERATION_RATE // (2 * num_categories)
    while True:
        for category in categories:
            for _ in range(num_texts_per_category):
                uid = uuid.uuid4()
                base_text = f"Category: {category}. ID:{uid}"
                text = translate_text(base_text, target_language=language, delay=2)
                text_queue.put((text, category))
        time.sleep(0)  # yield to other threads between passes
|
|
|
|
|
def background_training():
    global news_clf, tfidf_vectorizer, text_queue, is_training, feedback_queue, model, state_dict, enc, config, device, categories
    if is_training:
        print("Training running.")
        return
    is_training = True
    try:
        # Initial supervised fit on the 20 Newsgroups corpus.
        newsgroups = fetch_20newsgroups(subset='train')
        categories = newsgroups.target_names
        tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=10000)
        tfidf_matrix = tfidf_vectorizer.fit_transform(newsgroups.data)
        news_clf = LogisticRegression(random_state=42, solver='liblinear', multi_class='ovr')
        news_clf.fit(tfidf_matrix, newsgroups.target)
        print("Initial training done.")
        while True:
            try:
                text, category = text_queue.get()
                try:
                    # Self-training step: pseudo-label the queued text and refit.
                    # Note that LogisticRegression.fit on a single sample raises
                    # ValueError (it needs at least two classes), so this branch
                    # is effectively skipped; see the partial_fit sketch below.
                    tfidf_matrix_predict = tfidf_vectorizer.transform([text])
                    predicted_category_index_predict = news_clf.predict(tfidf_matrix_predict)[0]
                    new_texts = [text]
                    new_labels = [predicted_category_index_predict]
                    new_tfidf_matrix = tfidf_vectorizer.transform(new_texts)
                    try:
                        news_clf.fit(new_tfidf_matrix, new_labels)
                    except ValueError as e:
                        print(f"ValueError during news_clf.fit: {e}. Skip batch.")
                except IndexError as index_error:
                    print(f"IndexError during category prediction or fitting: {index_error}. Text: {text}, Category: {category}")
                except Exception as e_pred:
                    print(f"Error during category prediction or fitting: {e_pred}. Text: {text}, Category: {category}")

                try:
                    # User feedback: refit the classifier and nudge GPT-2.
                    feedback_text, correct_category = feedback_queue.get(timeout=0.1)
                    tfidf_matrix_feedback = tfidf_vectorizer.transform([feedback_text])
                    if correct_category in categories:
                        correct_category_index = categories.index(correct_category)
                    else:
                        print(f"Invalid feedback category: {correct_category}. Skipping feedback.")
                        continue
                    new_texts = [feedback_text]
                    new_labels = [correct_category_index]
                    new_tfidf_matrix_feedback = tfidf_vectorizer.transform(new_texts)
                    news_clf.fit(new_tfidf_matrix_feedback, new_labels)

                    try:
                        # One language-modeling step on the feedback text. A fresh
                        # Adam per step carries no optimizer state, so this is a
                        # single-sample nudge rather than real fine-tuning.
                        model.train()
                        context_tokens = enc.encode(feedback_text)
                        input_ids = torch.tensor([context_tokens], device=device)
                        optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
                        optimizer.zero_grad()
                        loss = model(input_ids, lm_labels=input_ids)  # forward returns the loss when lm_labels is given
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                        optimizer.step()
                        model.eval()

                        print("GPT-2 Fine-tuning iteration done.")
                        torch.save(model.state_dict(), os.path.join(GPT2_FOLDER, MODEL_FILE))

                    except Exception as gpt_e:
                        print(f"GPT-2 fine-tuning error {gpt_e}")

                except queue.Empty:
                    pass
                except Exception as e:
                    print(f"Feedback processing error {e}")

                time.sleep(0)
            except Exception as e:
                print(f"Training Error: {e}")
            finally:
                text_queue.task_done()
    except Exception as e:
        print(f"Fatal Training Error: {e}")
    finally:
        is_training = False
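
# A minimal sketch of genuinely incremental classifier updates, assuming one
# swaps LogisticRegression for a partial_fit-capable model such as
# sklearn.linear_model.SGDClassifier (illustrative; not what the code above does):
#   from sklearn.linear_model import SGDClassifier
#   clf = SGDClassifier(loss="log_loss", random_state=42)
#   all_classes = list(range(len(categories)))
#   clf.partial_fit(tfidf_matrix, newsgroups.target, classes=all_classes)  # initial pass
#   ...
#   clf.partial_fit(tfidf_vectorizer.transform([feedback_text]),
#                   [correct_category_index])                              # one feedback sample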
|
|
|
|
|
def analyze_input(text):
    reasoning = perform_reasoning(text)
    category = get_category(text)
    return {"reasoning": reasoning, "category": category}
|
|
|
|
|
app = Flask(__name__)


@app.route('/')
def index():
    return html_code
|
|
|
|
|
@app.route('/generate', methods=['POST'])
def generate():
    data = request.get_json()
    text = data.get('text', "")
    length = data.get('length', 50)
    temperature = data.get('temperature', 0.7)
    top_k = data.get('top_k', 40)
    correct_category = data.get('category')

    generated_text = generate_text(text, length, temperature, top_k)
    analysis = analyze_input(text)

    # The front end reuses this endpoint for feedback: requests carrying a
    # 'category' field are queued for the background trainer.
    if correct_category:
        feedback_queue.put((text, correct_category))

    response_data = {
        "generated_text": generated_text,
        "reasoning": analysis["reasoning"],
        "category": analysis["category"]
    }
    return jsonify(response_data)
|
|
|
|
|
def generate_text(text, length, temperature, top_k):
    if model is None:
        initialize_model()
    context_tokens = enc.encode(text)
    out = sample_sequence(model=model, length=length, context=context_tokens,
                          temperature=temperature, top_k=top_k, device=device)
    out = out[:, len(context_tokens):].tolist()
    text = enc.decode(out[0])
    return text
|
|
|
|
|
def initialize_sklearn():
    global news_clf, tfidf_vectorizer, categories
    try:
        newsgroups = fetch_20newsgroups(subset='train')
        categories = newsgroups.target_names
        tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=10000)
        tfidf_matrix = tfidf_vectorizer.fit_transform(newsgroups.data)
        news_clf = LogisticRegression(random_state=42, solver='liblinear', multi_class='ovr')
        news_clf.fit(tfidf_matrix, newsgroups.target)
        print(f"Categories: {categories}")
    except Exception as e:
        print(f"Sklearn Error: {e}")
        return
|
|
|
|
|
def run_app():
    app.run(host='0.0.0.0', debug=True, threaded=True, port=7860, use_reloader=False)
|
|
|
|
|
if __name__ == '__main__':
    ensure_files_exist()
    initialize_model()
    initialize_sklearn()

    # Two text-producer threads (English and Spanish), the trainer, and Flask.
    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('en',), daemon=True))
    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
    background_threads.append(threading.Thread(target=background_training, daemon=True))
    background_threads.append(threading.Thread(target=run_app, daemon=True))

    for thread in background_threads:
        thread.start()

    # Keep the main thread alive; all workers are daemon threads.
    while True:
        time.sleep(1)