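"""Gradio demo for a masked-diffusion quote generator.

Loads a QAT-trained ModernBERT masked-LM checkpoint, applies dynamic int8
quantization to its Linear layers, and generates text by iteratively
unmasking tokens with either a plain backward sampler or a
predictor-corrector sampler.
"""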
import os

# Cap BLAS/OpenMP thread pools before numpy/torch are imported so the limits take effect.
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"

import logging
import random

import gradio as gr
import torch
from torch.quantization import quantize_dynamic
from transformers import AutoModelForMaskedLM

from inference import (
    prepare_conditional_tokens_for_inference,
    prepare_unconditional_tokens_for_inference,
    clean_text,
)
from tokenizer import get_tokenizer

# PyTorch does not read a TORCH_NUM_THREADS environment variable; set the
# intra-op thread count explicitly instead.
torch.set_num_threads(2)

def seed_everything(seed: int):
    """Seed every RNG in use so a given seed reproduces the same generation."""
    import numpy as np

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Keep transformers' own logging quiet; only errors are shown.
logging.getLogger("transformers").setLevel(logging.ERROR)

def load_model_and_tokenizer(path_to_weights, hf_model_name, device="cpu"):
    logger.info(f"Loading tokenizer: {hf_model_name}")
    tokenizer = get_tokenizer(hf_model_name)

    logger.info(f"Loading base model: {hf_model_name}")
    model = AutoModelForMaskedLM.from_pretrained(hf_model_name, device_map="cpu")
    model.resize_token_embeddings(len(tokenizer))

    logger.info(f"Loading QAT-trained weights from {path_to_weights}")
    state_dict = torch.load(path_to_weights, map_location="cpu")
    load_result = model.load_state_dict(state_dict, strict=False)
    logger.info(f"State dict load result: {load_result}")

    model.eval()

    logger.info("Applying dynamic quantization to QAT weights...")
    model = quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )

    model.to(device)
    return model, tokenizer

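# inference_stream is a generator: it yields the decoded sequence after every
# denoising step so the Gradio textbox can show the text being progressively unmasked.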
@torch.inference_mode()
def inference_stream(model, tokenizer, num_steps, strategy, device, prompt, seq_len, seed, corrector_step_size):
    seed_everything(seed)

    if prompt:
        input_tokens, mask, attention_mask = prepare_conditional_tokens_for_inference(
            seq_len, tokenizer, prompt, device=device
        )
    else:
        input_tokens, mask, attention_mask = prepare_unconditional_tokens_for_inference(
            seq_len, tokenizer.mask_token_id, device=device
        )

    original_mask = mask.clone()
    times = torch.linspace(1, 0, num_steps + 1, device=device)

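    # Reverse diffusion: walk the time grid from t=1 down to t=0. At each step,
    # sample tokens for the currently masked positions, then keep each of those
    # positions masked with probability s/t, so the masked fraction shrinks along
    # the schedule (e.g. with num_steps=4 the first step has s/t = 0.75/1.0,
    # leaving roughly 75% of the masked positions masked).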
    for t, s in zip(times[:-1], times[1:]):
        logits = model(input_tokens, attention_mask=attention_mask).logits

        if strategy == "backward":
            # Sample a token for every masked position from the model's distribution.
            probs = torch.softmax(logits[mask], dim=-1)
            input_tokens[mask] = torch.multinomial(probs, num_samples=1).squeeze(-1)

            # Re-mask each of those positions with probability s/t.
            remask_probs = torch.rand_like(mask, dtype=torch.float, device=device)
            remask_decision = remask_probs < s / t
            mask = mask & remask_decision
            input_tokens[mask] = tokenizer.mask_token_id

        elif strategy == "predictor_corrector":
            # Predictor: identical to the backward step above.
            probs = torch.softmax(logits[mask], dim=-1)
            input_tokens[mask] = torch.multinomial(probs, num_samples=1).squeeze(-1)

            remask_probs = torch.rand_like(mask, dtype=torch.float, device=device)
            remask_decision = remask_probs < s / t
            mask = mask & remask_decision
            input_tokens[mask] = tokenizer.mask_token_id

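            # Corrector: re-mask a small fraction of positions that have already been
            # revealed and re-sample them, letting the model revise earlier choices.
            # The re-mask probability is corrector_step_size scaled by (t - s) / (1 - s).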
            n_corrector_steps = 1
            corrector_step_size_val = corrector_step_size * (t - s) / (1 - s)

            if n_corrector_steps > 0:
                for _ in range(n_corrector_steps):
                    # Positions that started out masked but have been filled in by now.
                    known_mask = ~mask & original_mask
                    noise_rng = torch.rand_like(known_mask, dtype=torch.float, device=device)
                    to_remask = known_mask & (noise_rng < corrector_step_size_val)

                    input_tokens[to_remask] = tokenizer.mask_token_id
                    corr_logits = model(input_tokens, attention_mask=attention_mask).logits
                    corr_probs = torch.softmax(corr_logits[to_remask], dim=-1)
                    input_tokens[to_remask] = torch.multinomial(corr_probs, num_samples=1).squeeze(-1)

        # Decode the current state, dropping special tokens but keeping the mask
        # token visible so the UI shows which positions are still undecided.
        decoded_tokens = tokenizer.convert_ids_to_tokens(input_tokens[0])
        cleaned_tokens = []
        for tok in decoded_tokens:
            if tok == tokenizer.mask_token:
                cleaned_tokens.append(tok)
            elif tok in tokenizer.all_special_tokens:
                continue
            else:
                cleaned_tokens.append(tok)
        decoded_after = tokenizer.convert_tokens_to_string(cleaned_tokens)

        if prompt:
            assistant_text = decoded_after.replace(prompt, "").strip()
            assistant_text = clean_text(assistant_text)

            # Fall back to the full decode while the response part is still empty.
            if not assistant_text:
                yield decoded_after
            else:
                yield assistant_text
        else:
            yield decoded_after

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)


MODEL_CACHE = {}

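# Gradio callback. Models are cached per (weights path, base model, device) so
# repeated generations do not reload or re-quantize the network; outputs from
# inference_stream are forwarded as they arrive so the textbox updates live.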
def run_app(safetensors_path, hf_model_name, device, prompt, seq_len, num_steps, strategy, use_manual_seed, seed, corrector_step_size):
    logger.info(f"Starting run_app with prompt: '{prompt}', device: {device}, strategy: {strategy}")

    if not use_manual_seed:
        seed = random.randint(0, 2**32 - 1)
    logger.info(f"Using seed: {seed}")

    cache_key = (safetensors_path, hf_model_name, device)
    if cache_key not in MODEL_CACHE:
        try:
            logger.info(f"Loading model from {safetensors_path}...")
            model, tokenizer = load_model_and_tokenizer(safetensors_path, hf_model_name, device)
            MODEL_CACHE[cache_key] = (model, tokenizer)
            logger.info("Model loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}", exc_info=True)
            yield f"Error loading model: {str(e)}"
            return

    model, tokenizer = MODEL_CACHE[cache_key]

    try:
        yield f"Using Seed: {seed}..."

        logger.info("Starting inference stream...")
        for output in inference_stream(model, tokenizer, num_steps, strategy, device, prompt, seq_len, seed, corrector_step_size):
            yield output
        logger.info("Inference stream finished.")
    except Exception as e:
        logger.error(f"Error during inference: {str(e)}", exc_info=True)
        yield f"Error during inference: {str(e)}"

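# In-memory log handler that backs the "Logs" panel in the UI.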
class ListLogHandler(logging.Handler):
    def __init__(self):
        super().__init__()
        self.log_records = []

    def emit(self, record):
        log_entry = self.format(record)
        self.log_records.append(log_entry)

    def get_logs(self):
        return "\n".join(self.log_records)


list_handler = ListLogHandler()
list_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
list_handler.setFormatter(list_formatter)
logging.getLogger().addHandler(list_handler)


def get_logs_content():
    return list_handler.get_logs()

css = """ |
|
|
.gradio-container { |
|
|
background: linear-gradient(to right, #0f2027, #203a43, #2c5364); |
|
|
} |
|
|
""" |
|
|
|
|
|
# Custom CSS is passed to gr.Blocks; Blocks.launch() does not take a css argument.
with gr.Blocks(title="Diffusion Quote Generator", css=css) as demo:
    gr.Markdown("# Diffusion Quote Generator")
    gr.Markdown("Generating text by iteratively removing noise (mask tokens).")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value="Generate a quote on hope")

            gr.Markdown("### Quick Prompts")
            selected_topics = gr.State(["Hope"])
            topic_buttons = []

            def toggle_topic(topic, current_selected):
                # Toggle the topic in the current selection and rebuild the prompt text.
                if topic in current_selected:
                    current_selected.remove(topic)
                    variant = "secondary"
                else:
                    current_selected.append(topic)
                    variant = "primary"

                if not current_selected:
                    new_prompt = "Generate a quote on hope"
                else:
                    topics_str = ", ".join(current_selected)
                    new_prompt = f"Generate a quote on {topics_str.lower()}"

                return new_prompt, current_selected, gr.Button(variant=variant)

            with gr.Row():
                topics = ["Hope", "Happiness", "Life", "Friendship", "Love", "Inspiration", "Suffering", "Faith"]
                for topic in topics:
                    default_variant = "primary" if topic == "Hope" else "secondary"
                    btn = gr.Button(topic, variant=default_variant)
                    topic_buttons.append(btn)

                    btn.click(
                        fn=toggle_topic,
                        inputs=[gr.State(topic), selected_topics],
                        outputs=[prompt_input, selected_topics, btn],
                    )

            with gr.Accordion("Advanced Settings", open=False):
                safetensors_input = gr.Textbox(label="QAT Weights Path (.pt)", value="qat_model_unconverted.pt")
                hf_model_input = gr.Textbox(label="Base HF Model", value="answerdotai/ModernBERT-base")
                device_input = gr.Dropdown(label="Device", choices=["cuda", "cpu"], value="cpu")
                strategy_input = gr.Dropdown(label="Strategy", choices=["predictor_corrector", "backward"], value="predictor_corrector")
                seq_len_input = gr.Slider(label="Sequence Length", minimum=16, maximum=64, value=64, step=16)
                num_steps_input = gr.Slider(label="Num Steps", minimum=2, maximum=128, value=32, step=2)
                corrector_step_size_input = gr.Slider(label="Corrector Step Scale", minimum=0.0, maximum=5.0, value=2.0, step=0.01)

                with gr.Row():
                    use_manual_seed = gr.Checkbox(label="Use Custom Seed", value=False)
                    seed_input = gr.Number(label="Seed", value=8734578, precision=0, visible=False)

                # Only show the seed box when the user asks for a fixed seed.
                use_manual_seed.change(fn=lambda x: gr.Number(visible=x), inputs=use_manual_seed, outputs=seed_input)

            run_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(label="Generated Text", interactive=False, lines=10)

            with gr.Accordion("Logs", open=False):
                logs_output = gr.Code(label="Application Logs", language="markdown", lines=10)
                refresh_logs_btn = gr.Button("Refresh Logs", size="sm")

                # Poll the in-memory handler every 2 seconds so the log view stays current.
                timer = gr.Timer(2)
                timer.tick(fn=get_logs_content, outputs=logs_output)

    demo.load(fn=get_logs_content, outputs=logs_output)
    refresh_logs_btn.click(fn=get_logs_content, outputs=logs_output)

    run_btn.click(
        fn=run_app,
        inputs=[safetensors_input, hf_model_input, device_input, prompt_input, seq_len_input, num_steps_input, strategy_input, use_manual_seed, seed_input, corrector_step_size_input],
        outputs=output_text,
    )

if __name__ == "__main__":
    # CSS is applied via gr.Blocks above; launch() does not accept a css kwarg.
    demo.queue().launch()