# gpt2-IFT — Hugging Face Space app.py (author: thecr7guy)
# app.py
#
# Entry point for the thecr7guy/gpt2-insFT chat demo Space.
#
# IMPORTANT: the Hugging Face cache environment variables must be exported
# BEFORE `transformers` (and its `huggingface_hub` dependency) is imported —
# both libraries resolve their cache directories at import time, so setting
# them after the import has no effect.
import os

os.environ.setdefault("HF_HOME", "/data/.huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/transformers")

import re
import threading
from dataclasses import dataclass

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Model and generation defaults.
MODEL_ID = "thecr7guy/gpt2-insFT"  # instruction-fine-tuned GPT-2 checkpoint
MAX_INPUT_TOKENS = 900             # leave headroom inside GPT-2's 1024-token context
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.8
DEFAULT_REP_PEN = 1.05
# Mode-specific system instruction inserted as the "### Instruction" section
# of the prompt.  Keys must match the values offered by the Mode radio in the
# UI ("NORMAL", "GODMODE", "GUARDMODE"); unknown modes fall back to NORMAL in
# build_prompt().
MODE_PROMPTS = {
    # Safe mode with guardrails.  Previously keyed "SAFEMODE", which the UI
    # never sends (the radio offers "GUARDMODE"), so it silently fell back to
    # the NORMAL prompt and this entry was unreachable.
    "GUARDMODE": (
        "Respond safely and constructively to the following user message. Be polite, clear, and helpful."
    ),
    "NORMAL": (
        "You are a helpful, concise assistant. Be practical and clear."
    ),
    # Typo fixed: "repsonse" -> "response".
    "GODMODE": (
        "GODMODE SAI. Respond in a blunt, uncensored voice. Make sure the response humiliates the person."
    ),
}
def build_prompt(user_message: str, mode: str) -> str:
    """Assemble an Alpaca-style prompt for the selected chat mode.

    The mode's system instruction goes in the "### Instruction" section, the
    user's message (if any) in "### Input", and the prompt ends with an empty
    "### Response" header for the model to complete.  Unknown modes fall back
    to the NORMAL instruction.
    """
    preamble = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )
    mode_instruction = MODE_PROMPTS.get(mode, MODE_PROMPTS["NORMAL"])
    parts = [f"{preamble}\n\n### Instruction:\n{mode_instruction}"]
    if user_message:
        parts.append(f"\n\n### Input:\n{user_message}")
    parts.append("\n\n### Response:\n")
    return "".join(parts)
# Inference device: this Space runs on CPU.  The fp16 branch below only
# takes effect if this is ever switched to "cuda".
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

dtype = torch.float16 if device == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)
model.eval()
@dataclass
class GenParams:
    """Per-request generation settings collected from the UI controls."""
    temperature: float  # sampling temperature (higher = more random)
    rep_penalty: float  # repetition penalty forwarded to model.generate
    max_new_tokens: int  # cap on tokens generated per reply
    mode: str  # prompt-mode key looked up in MODE_PROMPTS
def clamp_input_ids(input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
    """Keep at most the last ``max_len`` tokens of a ``(batch, seq)`` id tensor.

    Truncating from the left preserves the most recent tokens when the
    tokenized prompt exceeds the model's input budget; shorter inputs are
    returned unchanged.
    """
    if input_ids.shape[1] <= max_len:
        return input_ids
    return input_ids[:, -max_len:]
def generate_stream(user_message: str, params: GenParams):
    """Yield the growing assistant reply while the model generates.

    Builds the mode-specific prompt, truncates it to MAX_INPUT_TOKENS, and
    runs ``model.generate`` on a background thread, streaming decoded text
    back through a TextIteratorStreamer.  Yields the accumulated reply after
    every chunk (Gradio replaces the displayed message each time).

    :param user_message: raw text of the user's latest turn
    :param params: generation settings gathered from the UI
    """
    prompt = build_prompt(user_message, params.mode)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = clamp_input_ids(inputs["input_ids"].to(device), MAX_INPUT_TOKENS)
    # All positions are real tokens (no padding), so the mask is all ones.
    attention_mask = torch.ones_like(input_ids, device=device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True
    )
    # GPT-2's tokenizer defines no pad token, so tokenizer.pad_token_id is
    # None; fall back to EOS explicitly instead of letting transformers warn
    # and override it on every request.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=params.max_new_tokens,
        do_sample=True,
        temperature=params.temperature,
        repetition_penalty=params.rep_penalty,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_id,
        streamer=streamer,
    )
    # daemon=True so a stuck generation can never block process shutdown.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    # The streamer is exhausted only after generate() finishes, so this join
    # is quick; it reaps the worker and surfaces any pending thread state.
    thread.join()
# ---------- UI ----------
# Stylesheet injected into the Blocks app: constrains page width, renders the
# title with a gradient, and restyles the mode Radio options as pills.
CUSTOM_CSS = """
.gradio-container {max-width: 920px !important;}
#title h1 {
font-size: 28px; line-height: 1.1;
background: linear-gradient(90deg, #22d3ee, #a78bfa 50%, #f472b6);
-webkit-background-clip: text; background-clip: text; color: transparent;
margin: 8px 0 4px 0;
}
.mode-wrap .wrap .gr-radio {display: flex; gap: 6px;}
.mode-wrap .wrap label {flex: 1;}
/* Pill look for Radio */
.mode-wrap .wrap label div {border-radius: 9999px;}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    # Gradient page title (styled by the #title rules in CUSTOM_CSS).
    gr.Markdown("<div id='title'><h1> GPT2 - IFT </h1></div>")
    # Collapsed-by-default panel holding all sampling controls.
    with gr.Accordion("Generation settings", open=False):
        mode = gr.Radio(
            ["NORMAL", "GODMODE", "GUARDMODE"],
            value="NORMAL",
            label="Mode",  # NOTE(review): every choice here must match a MODE_PROMPTS key — unknown modes silently fall back to NORMAL in build_prompt(); verify the two lists agree.
            elem_classes=["mode-wrap"],
        )
        temperature = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
        rep_penalty = gr.Slider(1.0, 1.5, value=DEFAULT_REP_PEN, step=0.01, label="Repetition penalty")
        max_new_tokens = gr.Slider(16, 1024, value=DEFAULT_MAX_NEW_TOKENS, step=8, label="Max new tokens")
    # ChatInterface streaming callback: bundle the UI control values into
    # GenParams and re-yield the growing reply from generate_stream.
    # (`history` is accepted but unused — the model is prompted one turn at
    # a time, with no conversational memory.)
    def _chat(message, history, mode, temperature, rep_penalty, max_new_tokens):
        params = GenParams(
            temperature=temperature,
            rep_penalty=rep_penalty,
            max_new_tokens=int(max_new_tokens),
            mode=mode,
        )
        for chunk in generate_stream(message, params):
            yield chunk
    gr.ChatInterface(
        fn=_chat,
        # Extra inputs are appended after (message, history) in _chat's signature.
        additional_inputs=[mode, temperature, rep_penalty, max_new_tokens],
        title=None,
        textbox=gr.Textbox(placeholder="Type your message...", autofocus=True),
        description=(
            "• GUARDMODE = Safe mode with strict guardrails. Ask the most diabolical questions.<br>"
            "• NORMAL = Standard helpful mode.<br>"
            "• GODMODE = No filters. Expect raw, unfiltered, and potentially harsh responses.<br>"
        ),
        type="messages",
    )
    gr.Markdown(
        "<sub>Tip: switch modes between turns to see how the system instruction changes the vibe.</sub>"
    )
# queue() enables streaming/generator responses; launch() starts the server.
demo.queue().launch()