import os
import sys
import torch
import random
import re
import json
import math
import copy
import requests
from functools import lru_cache
from tqdm import trange
from torch.nn.parameter import Parameter
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
import time
import threading
import queue
from deep_translator import GoogleTranslator
from flask import Flask, request, jsonify
import torch.nn as nn
import torch.nn.functional as F
import uuid
import wget  # simple downloader used to fetch the GPT-2 checkpoint and tokenizer files
MODEL_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"
ENCODER_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/encoder.json"
VOCAB_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/vocab.bpe"
GPT2_FOLDER = "./GPT2"
MODEL_FILE = "gpt2-pytorch_model.bin"
ENCODER_FILE = "encoder.json"
VOCAB_FILE = "vocab.bpe"
TEXT_GENERATION_RATE = 40000  # target number of synthetic training texts per cycle, split across languages and categories
html_code = """<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><title>AI Text Generation Platform</title><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" integrity="sha512-9usAa10IRO0HhonpyAIVpjrylPvoDwiPUiKdWk5t3PyolY1cOd4DSE0Ga+ri4AuTroPR5aQvXU9xC6qOPnzFeg==" crossorigin="anonymous" referrerpolicy="no-referrer" /><style>body {font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;background: linear-gradient(to right, #667eea, #764ba2);color: #fff;margin: 0;padding: 0;overflow: hidden;display: flex;justify-content: center;align-items: center;min-height: 100vh;}.container {width: 80%;max-width: 900px;margin: 50px auto;background: rgba(255, 255, 255, 0.1);padding: 40px;border-radius: 15px;box-shadow: 0 5px 25px rgba(0, 0, 0, 0.2);backdrop-filter: blur(10px);border: 1px solid rgba(255, 255, 255, 0.2);transition: transform 0.3s ease-in-out;}.container:hover {transform: translateY(-5px);}.header {text-align: center;margin-bottom: 40px;color: #fff;}.header h1 {font-size: 3.5em;font-weight: 600;letter-spacing: 1px;text-shadow: 0 2px 5px rgba(0, 0, 0, 0.3);}.header p {font-size: 1.2em;opacity: 0.8;text-shadow: 0 1px 3px rgba(0, 0, 0, 0.3);}.form-group {margin-bottom: 30px;}.form-group label {display: block;margin-bottom: 10px;font-size: 1.1em;font-weight: 500;opacity: 0.9;text-shadow: 0 1px 2px rgba(0, 0, 0, 0.3);}.form-group textarea, .form-group input[type="text"] {width: 100%;padding: 15px;border: none;border-radius: 8px;background: rgba(255, 255, 255, 0.2);color: #fff;font-size: 16px;box-shadow: inset 0 2px 8px rgba(0, 0, 0, 0.2);backdrop-filter: blur(8px);transition: all 0.3s ease;}.form-group textarea:focus, .form-group input[type="text"]:focus {outline: none;background: rgba(255, 255, 255, 0.3);box-shadow: inset 0 3px 12px rgba(0, 0, 0, 0.3);}button {width: 100%;padding: 15px;border: none;border-radius: 8px;background: linear-gradient(to right, #43cea2, #185a9d);color: #fff;font-size: 18px;font-weight: 500;cursor: pointer;box-shadow: 0 3px 10px rgba(0, 0, 0, 0.3);transition: background 0.3s ease;position: relative;overflow: hidden;}button:hover {background: linear-gradient(to left, #43cea2, #185a9d);}button::before {content: '\\f021';font-family: FontAwesome;position: absolute;top: 0;left: -30px;width: 30px;height: 100%;background: rgba(255, 255, 255, 0.3);display: flex;align-items: center;justify-content: center;transition: left 0.4s ease;}button:hover::before {left: 100%;transform: translateX(-100%);}.animated-text {position: absolute;top: 50%;left: 50%;transform: translate(-50%, -50%);font-size: 6em;font-weight: bold;color: rgba(255, 255, 255, 0.05);pointer-events: none;z-index: -1;}#output {margin-top: 40px;padding: 25px;border-radius: 10px;background: rgba(255, 255, 255, 0.1);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.2);backdrop-filter: blur(10px);border: 1px solid rgba(255, 255, 255, 0.2);white-space: pre-wrap;word-break: break-word;opacity: 0;transition: opacity 1s ease, transform 0.5s ease;transform: translateY(20px);}#output.show {opacity: 1;transform: translateY(0);}#output strong {color: #93b8c2;font-weight: 600;}@keyframes fadeIn {from {opacity: 0;transform: translateY(-20px);}to {opacity: 1;transform: translateY(0);}}</style></head><body><div class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI 
POWERED</div><div class="container"><div class="header animate__animated animate__fadeInDown"><h1>AI Text Generation Platform</h1><p>Unleash your creativity with our advanced text generation tool.</p></div><div class="form-group animate__animated animate__fadeInLeft"><label for="text">Input Text:</label><textarea id="text" rows="5" placeholder="Enter your text here"></textarea></div><div class="form-group animate__animated animate__fadeInRight"><label for="length">Length:</label><input type="text" id="length" value="50" placeholder="Enter desired length"></div><div class="form-group animate__animated animate__fadeInLeft"><label for="temperature">Temperature:</label><input type="text" id="temperature" value="0.7" placeholder="Enter temperature (0-1)"></div><div class="form-group animate__animated animate__fadeInRight"><label for="top_k">Top K:</label><input type="text" id="top_k" value="40" placeholder="Enter top K value"></div><button onclick="generateText()" class="animate__animated animate__fadeInUp">Generate Text</button><div id="output" class="animate__animated"><strong >Generated Text:</strong><br><span id="generatedText"></span><br><br><strong >Reasoning:</strong><br><span id="reasoning"></span><br><br><strong >Category:</strong><br><span id="category"></span></div></div><script>async function generateText() {const inputText = document.getElementById("text").value;const length = parseInt(document.getElementById("length").value);const temperature = parseFloat(document.getElementById("temperature").value);const top_k = parseInt(document.getElementById("top_k").value);const data = {text: inputText,length: length,temperature: temperature,top_k: top_k};const response = await fetch('/generate', {method: 'POST',headers: {'Content-Type': 'application/json'},body: JSON.stringify(data)});const result = await response.json();document.getElementById("generatedText").innerText = result.generated_text;document.getElementById("reasoning").innerText = result.reasoning;document.getElementById("category").innerText = result.category;const outputDiv = document.getElementById("output");outputDiv.classList.add("show");applyFeedback(inputText, result.category)}async function applyFeedback(originalText, predictedCategory) {const correctCategory = prompt("Is category correct? Enter correct category or leave blank:");if (correctCategory) {const feedbackData = {text: originalText,category: correctCategory};await fetch('/generate', {method: 'POST',headers: {'Content-Type': 'application/json'},body: JSON.stringify(feedbackData)})}}</script></body></html>"""
class GPT2Config:
    # Defaults correspond to the GPT-2 small (117M) checkpoint fetched from MODEL_URL;
    # the positional/embedding sizes must match that checkpoint for load_weight to succeed.
    def __init__(self, vocab_size_or_config_json_file=50257, n_positions=1024, n_ctx=1024, n_embd=768, n_layer=12, n_head=12, layer_norm_epsilon=1e-5, initializer_range=0.02):
self.vocab_size = vocab_size_or_config_json_file
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
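# Byte-level BPE support: map every byte value to a printable unicode character so
# BPE merges can operate on arbitrary UTF-8 input without unknown tokens.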
@lru_cache()
def bytes_to_unicode():
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class Encoder:
def __init__(self, encoder, bpe_merges, errors='replace'):
self.encoder = encoder
self.decoder = {v:k for k,v in self.encoder.items()}
self.errors = errors
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
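        # `re`-compatible approximation of the original GPT-2 tokenization pattern,
        # which uses \p{L}/\p{N} classes from the third-party `regex` module.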
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
                except ValueError:  # `first` no longer occurs in the remainder of the word
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def get_encoder():
encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
    with open(encoder_path, 'r', encoding="utf-8") as f:
encoder = json.load(f)
with open(vocab_path, 'r', encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
return Encoder(encoder=encoder, bpe_merges=bpe_merges)
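# tanh approximation of the Gaussian Error Linear Unit activation, as used in GPT-2.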
def gelu(x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
class LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
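# A linear layer with an (in_features, out_features) weight layout, kept to match
# the TensorFlow-style layout of the GPT-2 checkpoint being loaded.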
class Conv1D(nn.Module):
def __init__(self, nf, nx):
super().__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = Parameter(w)
self.bias = Parameter(torch.zeros(nf))
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(*size_out)
return x
class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False):
super().__init__()
n_state = nx
assert n_state % config.n_head == 0
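        # Lower-triangular causal mask; registered as "bias" to match the checkpoint's buffer name.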
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale
self.c_attn = Conv1D(n_state * 3, nx)
self.c_proj = Conv1D(n_state, nx)
def _attn(self, q, k, v):
w = torch.matmul(q, k)
if self.scale:
w = w / math.sqrt(v.size(-1))
nd, ns = w.size(-2), w.size(-1)
b = self.bias[:, :, ns-nd:ns, :ns]
w = w * b - 1e10 * (1 - b)
w = nn.Softmax(dim=-1)(w)
return torch.matmul(w, v)
def merge_heads(self, x):
x = x.permute(0, 2, 1, 3).contiguous()
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
return x.view(*new_x_shape)
def split_heads(self, x, k=False):
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
x = x.view(*new_x_shape)
if k:
return x.permute(0, 2, 3, 1)
else:
return x.permute(0, 2, 1, 3)
def forward(self, x, layer_past=None):
x = self.c_attn(x)
query, key, value = x.split(self.split_size, dim=2)
query = self.split_heads(query)
key = self.split_heads(key, k=True)
value = self.split_heads(value)
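        # Reuse cached keys/values from earlier decoding steps (keys are cached transposed).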
if layer_past is not None:
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]
key = torch.cat((past_key, key), dim=-1)
value = torch.cat((past_value, value), dim=-2)
present = torch.stack((key.transpose(-2, -1), value))
a = self._attn(query, key, value)
a = self.merge_heads(a)
a = self.c_proj(a)
return a, present
class MLP(nn.Module):
def __init__(self, n_state, config):
super().__init__()
nx = config.n_embd
self.c_fc = Conv1D(n_state, nx)
self.c_proj = Conv1D(nx, n_state)
self.act = gelu
def forward(self, x):
h = self.act(self.c_fc(x))
h2 = self.c_proj(h)
return h2
class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False):
super().__init__()
nx = config.n_embd
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.attn = Attention(nx, n_ctx, config, scale)
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config)
def forward(self, x, layer_past=None):
a, present = self.attn(self.ln_1(x), layer_past=layer_past)
x = x + a
m = self.mlp(self.ln_2(x))
x = x + m
return x, present
class GPT2Model(nn.Module):
def __init__(self, config):
super().__init__()
self.n_layer = config.n_layer
self.n_embd = config.n_embd
self.n_vocab = config.vocab_size
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
block = Block(config.n_ctx, config, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
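    # Note: this method is unused here; weight tying is handled by GPT2LMHead via set_tied().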
def set_embeddings_weights(self, model_embeddings_weights):
embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.decoder.weight = model_embeddings_weights
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
if past is None:
past_length = 0
past = [None] * len(self.h)
else:
past_length = past[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_ids.size(-1))
position_ids = position_ids.view(-1, position_ids.size(-1))
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
token_type_embeds = self.wte(token_type_ids)
else:
token_type_embeds = 0
hidden_states = inputs_embeds + position_embeds + token_type_embeds
presents = []
for block, layer_past in zip(self.h, past):
hidden_states, present = block(hidden_states, layer_past)
presents.append(present)
hidden_states = self.ln_f(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
return hidden_states.view(*output_shape), presents
class GPT2LMHead(nn.Module):
def __init__(self, model_embeddings_weights, config):
super().__init__()
self.n_embd = config.n_embd
self.set_embeddings_weights(model_embeddings_weights)
def set_embeddings_weights(self, model_embeddings_weights):
embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.decoder.weight = model_embeddings_weights
def forward(self, hidden_state):
lm_logits = self.decoder(hidden_state)
return lm_logits
class GPT2LMHeadModel(nn.Module):
def __init__(self, config):
super().__init__()
self.transformer = GPT2Model(config)
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
def set_tied(self):
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
lm_logits = self.lm_head(hidden_states)
if lm_labels is not None:
loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
return loss
return lm_logits, presents
def top_k_logits(logits, k):
    # Mask all but the k highest logits per row by pushing the rest to -1e10.
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(-1)  # k-th largest logit per row; broadcasts over the vocab dim
    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
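# Autoregressive sampling loop: after the first step, feed only the newest token and
# reuse the per-layer key/value cache ("past") so each forward pass stays cheap.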
def sample_sequence(model, length, start_token=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
if start_token is None:
assert context is not None, 'Specify exactly one of start_token and context!'
context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
else:
assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((1, 1), start_token, device=device, dtype=torch.long)
prev = context
output = context
past = None
with torch.no_grad():
for _ in trange(length):
logits, past = model(prev, past=past)
logits = logits[:, -1, :] / temperature
logits = top_k_logits(logits, k=top_k)
            probs = F.softmax(logits, dim=-1)  # softmax yields probabilities, not log-probabilities
            if sample:
                prev = torch.multinomial(probs, num_samples=1)
            else:
                _, prev = torch.topk(probs, k=1, dim=-1)
output = torch.cat((output, prev), dim=1)
return output
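# Rename TF-style checkpoint keys (.g/.b/.w) to PyTorch names (.weight/.bias), then
# load them recursively, tolerating a state dict without the "transformer." prefix.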
def load_weight(model, state_dict):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if key.endswith(".g"):
new_key = key[:-2] + ".weight"
elif key.endswith(".b"):
new_key = key[:-2] + ".bias"
elif key.endswith(".w"):
new_key = key[:-2] + ".weight"
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
start_model = model
if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
start_model = model.transformer
load(start_model, prefix="")
model.set_tied()
return model
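# Thin retry wrapper around wget.download; waits 2 seconds between failed attempts.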
def download_file(url, filename, retries=3):
for attempt in range(retries):
try:
wget.download(url, out=filename)
return True
except Exception as e:
print(f"Download error on attempt {attempt + 1}: {e}")
if attempt < retries - 1:
time.sleep(2)
else:
print(f"Failed to download {url} after {retries} attempts.")
return False
return False
def ensure_files_exist():
global MODEL_URL, ENCODER_URL, VOCAB_URL
if not os.path.exists(GPT2_FOLDER):
os.makedirs(GPT2_FOLDER)
model_path = os.path.join(GPT2_FOLDER, MODEL_FILE)
encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
if not os.path.exists(model_path):
print(f"Downloading model from {MODEL_URL} to {model_path}")
if not download_file(MODEL_URL, model_path):
print("Exiting due to model download failure.")
sys.exit(1)
if not os.path.exists(encoder_path):
print(f"Downloading encoder from {ENCODER_URL} to {encoder_path}")
if not download_file(ENCODER_URL, encoder_path):
print("Exiting due to encoder download failure.")
sys.exit(1)
if not os.path.exists(vocab_path):
print(f"Downloading vocab from {VOCAB_URL} to {vocab_path}")
if not download_file(VOCAB_URL, vocab_path):
print("Exiting due to vocab download failure.")
sys.exit(1)
def translate_text(text, target_language='es', delay=1):
try:
translator = GoogleTranslator(source='auto', target=target_language)
translated_text = translator.translate(text)
time.sleep(delay)
return translated_text
except Exception as e:
print(f"Translation Error: {e}")
return text
state_dict = None
enc = None
config = None
model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
news_clf = None
tfidf_vectorizer = None
text_queue = queue.Queue()
categories = None
is_training = False
background_threads = []
feedback_queue = queue.Queue()
def initialize_model():
global state_dict, enc, config, model, device, GPT2_FOLDER, MODEL_FILE
if state_dict is None:
ensure_files_exist()
model_path = os.path.join(GPT2_FOLDER, MODEL_FILE)
state_dict = torch.load(model_path, map_location=device)
enc = get_encoder()
config = GPT2Config()
model = GPT2LMHeadModel(config).to(device)
model = load_weight(model, state_dict)
model.eval()
def perform_reasoning(text):
if model is None:
initialize_model()
try:
reasoning_prompt = f"Given: '{text}', what's inferred?"
context_tokens = enc.encode(reasoning_prompt)
out = sample_sequence(model=model, length=100, context=context_tokens, temperature=0.7, top_k=40, device=device)
out = out[:, len(context_tokens):].tolist()
return enc.decode(out[0])
except Exception as e:
print(f"Reasoning Error: {e}")
return ""
def get_category(text):
global news_clf, categories, tfidf_vectorizer
if news_clf is None or tfidf_vectorizer is None or categories is None:
return "Not initialized"
try:
tfidf_matrix = tfidf_vectorizer.transform([text])
predicted_category_index = news_clf.predict(tfidf_matrix)[0]
if 0 <= predicted_category_index < len(categories):
return categories[predicted_category_index]
else:
print(f"Error: Predicted category index {predicted_category_index} out of bounds (categories size: {len(categories)}). Returning 'Unknown Category'.")
return "Unknown Category"
except IndexError as e:
print(f"IndexError in get_category: {e}. Categories: {categories}. Returning 'Error Category'.")
return "Error Category"
except Exception as e:
print(f"Category Prediction Error: {e}")
return "Error"
def generate_and_queue_text(language):
global text_queue, categories
if categories is None:
print("Categories not initialized.")
return
num_categories = len(categories)
num_texts_per_category = TEXT_GENERATION_RATE // (2 * num_categories)
while True:
for category in categories:
for _ in range(num_texts_per_category):
uid = uuid.uuid4()
base_text = f"Category: {category}. ID:{uid}"
text = translate_text(base_text, target_language=language, delay=2)
text_queue.put((text, category))
time.sleep(0)
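# Background trainer: bootstraps a TF-IDF + LogisticRegression classifier on
# 20 Newsgroups, then loops forever consuming text_queue (self-training) and
# feedback_queue (user corrections, which also trigger one GPT-2 update step).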
def background_training():
global news_clf, tfidf_vectorizer, text_queue, is_training, feedback_queue, model, state_dict, enc, config, device, categories
if is_training:
print("Training running.")
return
is_training = True
try:
newsgroups = fetch_20newsgroups(subset='train')
categories = newsgroups.target_names
tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=10000)
tfidf_matrix = tfidf_vectorizer.fit_transform(newsgroups.data)
        news_clf = LogisticRegression(random_state=42, solver='liblinear')  # liblinear is one-vs-rest by construction
news_clf.fit(tfidf_matrix, newsgroups.target)
print("Initial training done.")
while True:
try:
text, category = text_queue.get()
try:
tfidf_matrix_predict = tfidf_vectorizer.transform([text])
predicted_category_index_predict = news_clf.predict(tfidf_matrix_predict)[0]
new_texts = [text]
new_labels = [predicted_category_index_predict]
new_tfidf_matrix = tfidf_vectorizer.transform(new_texts)
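                    # Caution: LogisticRegression.fit retrains from scratch, and a
                    # single-sample, single-class batch raises ValueError (caught below),
                    # so this self-training update is effectively skipped; an incremental
                    # model with partial_fit (e.g. SGDClassifier) would be needed here.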
try:
news_clf.fit(new_tfidf_matrix, new_labels)
except ValueError as e:
print(f"ValueError during news_clf.fit: {e}. Skip batch.")
except IndexError as index_error:
print(f"IndexError during category prediction or fitting: {index_error}. Text: {text}, Category: {category}")
except Exception as e_pred:
print(f"Error during category prediction or fitting: {e_pred}. Text: {text}, Category: {category}")
try:
feedback_text, correct_category = feedback_queue.get(timeout=0.1)
tfidf_matrix_feedback = tfidf_vectorizer.transform([feedback_text])
if correct_category in categories:
correct_category_index = categories.index(correct_category)
else:
print(f"Invalid feedback category: {correct_category}. Skipping feedback.")
continue
new_texts = [feedback_text]
new_labels = [correct_category_index]
new_tfidf_matrix_feedback = tfidf_vectorizer.transform(new_texts)
                    news_clf.fit(new_tfidf_matrix_feedback, new_labels)  # same single-class fit caveat as above
                    try:
                        model.train()
                        context_tokens = enc.encode(feedback_text)
                        input_ids = torch.tensor([context_tokens], device=device)
                        # With lm_labels set, GPT2LMHeadModel returns the LM loss directly.
                        loss = model(input_ids, lm_labels=input_ids)
                        # A fresh optimizer per iteration discards Adam's moment estimates;
                        # acceptable for this single-step online update.
                        optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
                        optimizer.zero_grad()
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                        optimizer.step()
                        model.eval()
                        print("GPT-2 fine-tuning iteration done.")
                        torch.save(model.state_dict(), os.path.join(GPT2_FOLDER, MODEL_FILE))
                    except Exception as gpt_e:
                        print(f"GPT-2 fine-tuning error: {gpt_e}")
except queue.Empty:
pass
except Exception as e:
print(f"Feedback processing error {e}")
time.sleep(0)
except Exception as e:
print(f"Training Error: {e}")
finally:
text_queue.task_done()
except Exception as e:
print(f"Fatal Training Error: {e}")
finally:
is_training = False
def analyze_input(text):
reasoning = perform_reasoning(text)
category = get_category(text)
return {"reasoning": reasoning, "category": category}
app = Flask(__name__)
@app.route('/')
def index():
return html_code
@app.route('/generate', methods=['POST'])
def generate():
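    # Doubles as the feedback endpoint: the front-end re-posts here with a
    # 'category' field, which is queued for online retraining below.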
data = request.get_json()
text = data.get('text', "")
length = data.get('length', 50)
temperature = data.get('temperature', 0.7)
top_k = data.get('top_k', 40)
correct_category = data.get('category')
generated_text = generate_text(text, length, temperature, top_k)
analysis = analyze_input(text)
if correct_category:
feedback_queue.put((text, correct_category))
response_data = {
"generated_text": generated_text,
"reasoning": analysis["reasoning"],
"category": analysis["category"]
}
return jsonify(response_data)
def generate_text(text, length, temperature, top_k):
if model is None:
initialize_model()
context_tokens = enc.encode(text)
out = sample_sequence(model=model, length=length, context=context_tokens, temperature=temperature, top_k=top_k, device=device)
out = out[:, len(context_tokens):].tolist()
text = enc.decode(out[0])
return text
def initialize_sklearn():
global news_clf, tfidf_vectorizer, categories
try:
newsgroups = fetch_20newsgroups(subset='train')
categories = newsgroups.target_names
        tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=10000)
        tfidf_matrix = tfidf_vectorizer.fit_transform(newsgroups.data)
        news_clf = LogisticRegression(random_state=42, solver='liblinear')  # liblinear is one-vs-rest by construction
news_clf.fit(tfidf_matrix, newsgroups.target)
print(f"Categories: {categories}")
except Exception as e:
print(f"Sklearn Error: {e}")
return
def run_app():
app.run(host='0.0.0.0', debug=True, threaded=True, port=7860, use_reloader=False)
if __name__ == '__main__':
ensure_files_exist()
initialize_model()
initialize_sklearn()
background_threads.append(threading.Thread(target=generate_and_queue_text, args=('en',), daemon=True))
background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
background_threads.append(threading.Thread(target=background_training, daemon=True))
background_threads.append(threading.Thread(target=run_app, daemon=True))
for thread in background_threads:
thread.start()
while True:
time.sleep(1)