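"""Gradio demo: masked-diffusion quote generation on CPU.

Loads a QAT-trained ModernBERT masked-LM checkpoint, applies dynamic INT8
quantization to its Linear layers, and iteratively denoises mask tokens to
generate quotes, streaming intermediate results to the UI.
"""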
import os
# CRITICAL: Must be set before importing torch
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
os.environ["TORCH_NUM_THREADS"] = "2"
import torch
from transformers import AutoModelForMaskedLM
from tokenizer import get_tokenizer
import gradio as gr
import logging
import random
import torch.quantization
# Re-using logic from inference.py
from inference import prepare_conditional_tokens_for_inference, prepare_unconditional_tokens_for_inference, clean_text
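# As used below, the prepare_*_tokens_for_inference helpers return
# (input_ids, a boolean mask of positions currently holding the mask token,
# attention_mask), and clean_text strips decoding artifacts from the output.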
def seed_everything(seed: int):
    import numpy as np
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
logging.getLogger("transformers").setLevel(logging.ERROR)
def load_model_and_tokenizer(path_to_weights, hf_model_name, device="cpu"):
logger.info(f"Loading tokenizer: {hf_model_name}")
tokenizer = get_tokenizer(hf_model_name)
logger.info(f"Loading base model: {hf_model_name}")
# 1. Load the standard FP32 model structure
model = AutoModelForMaskedLM.from_pretrained(hf_model_name, device_map="cpu")
model.resize_token_embeddings(len(tokenizer))
logger.info(f"Loading QAT-trained weights from {path_to_weights}")
# 2. Load the QAT weights
# strict=False is REQUIRED because the checkpoint contains extra "observer" keys
# (stats about quantization) that the standard model doesn't need.
state_dict = torch.load(path_to_weights, map_location="cpu")
    load_result = model.load_state_dict(state_dict, strict=False)
    logger.info(f"load_state_dict result: {load_result}")
model.eval()
    # 3. Apply Dynamic Quantization
    # This produces a model that accepts FP32 activations but stores the Linear
    # weights as INT8, which speeds up CPU inference.
logger.info("Applying dynamic quantization to QAT weights...")
model = torch.quantization.quantize_dynamic(
model,
{torch.nn.Linear},
dtype=torch.qint8
)
    # Dynamically quantized modules run on CPU, so device should normally be "cpu".
    model.to(device)
return model, tokenizer
@torch.inference_mode()
def inference_stream(model, tokenizer, num_steps, strategy, device, prompt, seq_len, seed, corrector_step_size):
# Set seed
seed_everything(seed)
# Prepare tokens
if prompt:
input_tokens, mask, attention_mask = prepare_conditional_tokens_for_inference(
seq_len, tokenizer, prompt, device=device
)
else:
input_tokens, mask, attention_mask = prepare_unconditional_tokens_for_inference(
seq_len, tokenizer.mask_token_id, device=device
)
original_mask = mask.clone()
times = torch.linspace(1, 0, num_steps + 1, device=device)
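    # Reverse-diffusion schedule: t runs from 1 (fully masked) down to 0 (fully
    # denoised). Each iteration denoises from time t to the next time s < t.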
for t, s in zip(times[:-1], times[1:]):
# Model forward
logits = model(input_tokens, attention_mask=attention_mask).logits
if strategy == "backward":
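            # Ancestral step: sample every currently-masked position from the
            # model, then independently keep each position masked with
            # probability s/t so the masked fraction roughly tracks the schedule.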
probs = torch.softmax(logits[mask], dim=-1)
input_tokens[mask] = torch.multinomial(probs, num_samples=1).squeeze(-1)
remask_probs = torch.rand_like(mask, dtype=torch.float, device=device)
remask_probs = (remask_probs < s/t)
mask = mask & remask_probs
input_tokens[mask] = tokenizer.mask_token_id
elif strategy == "predictor_corrector":
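            # Predictor-corrector: the predictor is the same ancestral step as
            # "backward"; the corrector then re-masks a small fraction of
            # already-decoded tokens and resamples them to revise earlier choices.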
# Predictor
probs = torch.softmax(logits[mask], dim=-1)
input_tokens[mask] = torch.multinomial(probs, num_samples=1).squeeze(-1)
remask_probs = torch.rand_like(mask, dtype=torch.float, device=device)
remask_decision = (remask_probs < s/t)
mask = mask & remask_decision
input_tokens[mask] = tokenizer.mask_token_id
# Corrector
n_corrector_steps = 1
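            # Scale the user-set corrector strength by (t - s) / (1 - s): roughly
            # the share of currently-revealed tokens that were revealed in this
            # step, so the corrector noise stays proportional to the step size.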
corrector_step_size_val = corrector_step_size * (t - s) / (1 - s)
# corrector_step_size_val = corrector_step_size
if n_corrector_steps > 0:
for _ in range(n_corrector_steps):
                    # Positions that started masked but have since been decoded
                    # (mask is always a subset of original_mask here).
                    known_mask = original_mask & ~mask
noise_rng = torch.rand_like(known_mask, dtype=torch.float, device=device)
to_remask = known_mask & (noise_rng < corrector_step_size_val)
input_tokens[to_remask] = tokenizer.mask_token_id
corr_logits = model(input_tokens, attention_mask=attention_mask).logits
corr_probs = torch.softmax(corr_logits[to_remask], dim=-1)
input_tokens[to_remask] = torch.multinomial(corr_probs, num_samples=1).squeeze(-1)
# Decode for streaming
decoded_tokens = tokenizer.convert_ids_to_tokens(input_tokens[0])
cleaned_tokens = []
for tok in decoded_tokens:
if tok == tokenizer.mask_token:
cleaned_tokens.append(tok)
elif tok in tokenizer.all_special_tokens:
continue
else:
cleaned_tokens.append(tok)
decoded_after = tokenizer.convert_tokens_to_string(cleaned_tokens)
if prompt:
# Remove prompt for cleaner display
assistant_text = decoded_after.replace(prompt, "").strip()
# Clean artifacts
assistant_text = clean_text(assistant_text)
            if not assistant_text:
                # If cleaning stripped everything (e.g. the model has only echoed
                # the prompt so far), fall back to the raw decode so the UI still
                # shows progress.
                yield decoded_after
else:
yield assistant_text
else:
yield decoded_after
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
# Global model cache to avoid reloading
MODEL_CACHE = {}
def run_app(safetensors_path, hf_model_name, device, prompt, seq_len, num_steps, strategy, use_manual_seed, seed, corrector_step_size):
logger.info(f"Starting run_app with prompt: '{prompt}', device: {device}, strategy: {strategy}")
# Handle random seed
if not use_manual_seed:
seed = random.randint(0, 2**32 - 1)
logger.info(f"Using seed: {seed}")
# Load model if needed
cache_key = (safetensors_path, hf_model_name, device)
if cache_key not in MODEL_CACHE:
try:
logger.info(f"Loading model from {safetensors_path}...")
model, tokenizer = load_model_and_tokenizer(safetensors_path, hf_model_name, device)
MODEL_CACHE[cache_key] = (model, tokenizer)
logger.info("Model loaded successfully.")
except Exception as e:
logger.error(f"Error loading model: {str(e)}", exc_info=True)
yield f"Error loading model: {str(e)}"
return
model, tokenizer = MODEL_CACHE[cache_key]
# Run inference generator
try:
# Yield status info first (optional UX improvement)
yield f"Using Seed: {seed}..."
logger.info("Starting inference stream...")
for output in inference_stream(model, tokenizer, num_steps, strategy, device, prompt, seq_len, seed, corrector_step_size):
            # logger.debug(f"Stream output: {output}")  # very verbose; enable only when debugging
yield output
logger.info("Inference stream finished.")
except Exception as e:
logger.error(f"Error during inference: {str(e)}", exc_info=True)
yield f"Error during inference: {str(e)}"
# Custom handler to capture logs for UI
class ListLogHandler(logging.Handler):
def __init__(self):
super().__init__()
self.log_records = []
def emit(self, record):
log_entry = self.format(record)
self.log_records.append(log_entry)
def get_logs(self):
return "\n".join(self.log_records)
# Setup custom handler
list_handler = ListLogHandler()
list_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
list_handler.setFormatter(list_formatter)
logging.getLogger().addHandler(list_handler)
def get_logs_content():
return list_handler.get_logs()
# Gradio UI
css = """
.gradio-container {
background: linear-gradient(to right, #0f2027, #203a43, #2c5364);
}
"""
with gr.Blocks(title="Diffusion Quote Generator", css=css) as demo:
gr.Markdown("# Diffusion Quote Generator")
gr.Markdown("Generating text by iteratively removing noise (mask tokens).")
with gr.Row():
with gr.Column():
prompt_input = gr.Textbox(label="Prompt", value="Generate a quote on hope")
gr.Markdown("### Quick Prompts")
selected_topics = gr.State(["Hope"])
topic_buttons = []
# Helper logic for toggling topics
def toggle_topic(topic, current_selected):
if topic in current_selected:
current_selected.remove(topic)
variant = "secondary"
else:
current_selected.append(topic)
variant = "primary"
# Dynamic prompt generation
if not current_selected:
new_prompt = "Generate a quote on hope"
else:
topics_str = ", ".join(current_selected)
new_prompt = f"Generate a quote on {topics_str.lower()}"
return new_prompt, current_selected, gr.Button(variant=variant)
with gr.Row():
topics = ["Hope", "Happiness", "Life", "Friendship", "Love", "Inspiration", "Suffering", "Faith"]
for topic in topics:
# check if active
default_variant = "primary" if topic == "Hope" else "secondary"
btn = gr.Button(topic, variant=default_variant)
topic_buttons.append(btn)
# Connect click event
# Inputs: current state. Outputs: prompt, state, button itself
btn.click(
fn=toggle_topic,
inputs=[gr.State(topic), selected_topics],
outputs=[prompt_input, selected_topics, btn]
)
with gr.Accordion("Advanced Settings", open=False):
                safetensors_input = gr.Textbox(label="Checkpoint Path (.pt)", value="qat_model_unconverted.pt")
hf_model_input = gr.Textbox(label="Base HF Model", value="answerdotai/ModernBERT-base")
device_input = gr.Dropdown(label="Device", choices=["cuda", "cpu"], value="cpu")
strategy_input = gr.Dropdown(label="Strategy", choices=["predictor_corrector", "backward"], value="predictor_corrector")
seq_len_input = gr.Slider(label="Sequence Length", minimum=16, maximum=64, value=64, step=16)
num_steps_input = gr.Slider(label="Num Steps", minimum=2, maximum=128, value=32, step=2)
corrector_step_size_input = gr.Slider(label="Corrector Step Scale", minimum=0.0, maximum=5.0, value=2.0, step=0.01)
with gr.Row():
use_manual_seed = gr.Checkbox(label="Use Custom Seed", value=False)
seed_input = gr.Number(label="Seed", value=8734578, precision=0, visible=False)
# Toggle visibility of seed input
use_manual_seed.change(fn=lambda x: gr.Number(visible=x), inputs=use_manual_seed, outputs=seed_input)
run_btn = gr.Button("Generate", variant="primary")
with gr.Column():
output_text = gr.Textbox(label="Generated Text", interactive=False, lines=10)
with gr.Accordion("Logs", open=False):
logs_output = gr.Code(label="Application Logs", language="markdown", lines=10)
refresh_logs_btn = gr.Button("Refresh Logs", size="sm")
# Auto-refresh logs every 2 seconds
timer = gr.Timer(2)
timer.tick(fn=get_logs_content, outputs=logs_output)
# Initial load
demo.load(fn=get_logs_content, outputs=logs_output)
refresh_logs_btn.click(fn=get_logs_content, outputs=logs_output)
run_btn.click(
fn=run_app,
inputs=[safetensors_input, hf_model_input, device_input, prompt_input, seq_len_input, num_steps_input, strategy_input, use_manual_seed, seed_input, corrector_step_size_input],
outputs=output_text
)
if __name__ == "__main__":
    demo.queue().launch()