# aap / app.py
# curiouscurrent's picture
# Update app.py
# 69b1298 verified
# app.py - Gradio UI for interacting with facebook/opt-125m
import logging

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Toxicity scoring is optional: the app still works when detoxify is missing.
try:
    from detoxify import Detoxify
except Exception:
    # Any failure while importing detoxify just switches the feature off.
    detox_available = False
else:
    detox_available = True

# Checkpoint passed to the tokenizer and model loaders.
MODEL_NAME = "facebook/opt-125m"

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def load_models():
    """Load the tokenizer, the causal LM and (optionally) the toxicity model.

    Returns:
        A ``(tokenizer, model, detox)`` tuple. ``model`` has been moved to
        ``DEVICE`` and switched to eval mode. ``detox`` is a
        ``Detoxify('original')`` instance when ``detox_available`` is true,
        otherwise ``None``.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

    lm = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    lm.to(DEVICE)
    lm.eval()

    if detox_available:
        toxicity_model = Detoxify('original')
    else:
        toxicity_model = None

    return tok, lm, toxicity_model


# Loaded once at import time; generate() reads these module-level names.
tokenizer, model, detox = load_models()
@torch.inference_mode()
def generate(prompt, max_new_tokens=150, temperature=0.8, top_p=0.95, return_toxicity=False):
    """Sample a continuation of ``prompt`` from the loaded model.

    Args:
        prompt: Text to continue.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature (coerced to ``float``).
        top_p: Nucleus-sampling probability mass (coerced to ``float``).
        return_toxicity: When true and the detoxify model is loaded, also
            score the continuation.

    Returns:
        A ``(continuation, toxicity_score)`` tuple. ``continuation`` is only
        the newly generated text, stripped of surrounding whitespace.
        ``toxicity_score`` is ``None`` when scoring was not requested, the
        detoxify model is not loaded, or scoring raised.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    prompt_length = inputs["input_ids"].shape[-1]

    out = model.generate(
        **inputs,
        do_sample=True,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        pad_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence is the prompt tokens followed by the new tokens.
    # Slice by token count instead of matching the decoded text against the
    # prompt: decoding does not always reproduce the prompt verbatim, and a
    # failed prefix match used to echo the whole prompt back to the user.
    new_token_ids = out[0][prompt_length:]
    continuation = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    toxicity_score = None
    if return_toxicity and detox is not None:
        try:
            toxicity_score = detox.predict(continuation)["toxicity"]
        except Exception:
            # Scoring stays best-effort, but leave a trace instead of
            # swallowing the failure silently.
            logging.getLogger(__name__).exception("Toxicity scoring failed")

    return continuation, toxicity_score
# Gradio front end: a prompt box, sampling controls, a trigger button and two
# result fields.
# NOTE(review): the original indentation was lost; the nesting below (prompt
# and control column inside the row, everything else directly under Blocks)
# is the most plausible reading - confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("# OPT-125M Interactive")

    with gr.Row():
        inp = gr.Textbox(
            label="Prompt",
            placeholder="Type something to the model...",
            lines=3,
        )
        with gr.Column():
            max_tokens = gr.Slider(
                minimum=10, maximum=512, value=150, step=10,
                label="Max new tokens",
            )
            temp = gr.Slider(
                minimum=0.1, maximum=1.5, value=0.8, step=0.05,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.01,
                label="Top-p (nucleus)",
            )
            tox_checkbox = gr.Checkbox(
                value=False,
                label="Return toxicity score (requires detoxify)",
            )

    run_btn = gr.Button("Generate")
    output_text = gr.Textbox(label="Model output", lines=8)
    tox_out = gr.Textbox(label="Toxicity score (None if unavailable)", lines=1)

    def on_click(prompt, max_new_tokens, temperature, top_p, tox):
        """Run generation and format the results for the two output boxes."""
        generated, score = generate(prompt, max_new_tokens, temperature, top_p, tox)
        if score is None:
            return generated, "Not available"
        return generated, str(score)

    run_btn.click(
        fn=on_click,
        inputs=[inp, max_tokens, temp, top_p, tox_checkbox],
        outputs=[output_text, tox_out],
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()