# Hugging Face Spaces status header (page-capture residue, not application code):
# Spaces: Sleeping
# Sleeping
# app.py - Gradio UI for interacting with facebook/opt-125m
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Toxicity scoring is optional: detoxify is an extra dependency, so degrade
# gracefully to "no score available" when it cannot be imported.
try:
    from detoxify import Detoxify
    detox_available = True
except Exception:
    detox_available = False

MODEL_NAME = "facebook/opt-125m"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_models():
    """Build and return (tokenizer, model, detox).

    The model is moved to DEVICE and switched to eval mode; ``detox`` is a
    Detoxify scorer when the package is installed, otherwise None.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
    lm = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    lm.to(DEVICE)
    lm.eval()
    scorer = Detoxify('original') if detox_available else None
    return tok, lm, scorer


# Load everything once at import time so the Gradio callback can reuse it.
tokenizer, model, detox = load_models()
def generate(prompt, max_new_tokens=150, temperature=0.8, top_p=0.95, return_toxicity=False):
    """Sample a continuation of ``prompt`` from the OPT-125M model.

    Parameters
    ----------
    prompt : str
        Text fed to the model verbatim. Empty/whitespace-only prompts
        short-circuit to ("", None) instead of free-running the model.
    max_new_tokens, temperature, top_p :
        Sampling controls; coerced to int/float because Gradio sliders may
        hand them over as non-native numeric types.
    return_toxicity : bool
        When True and detoxify is available, score the continuation.

    Returns
    -------
    tuple[str, float | None]
        (continuation, toxicity_score); the score is None whenever scoring
        is disabled, unavailable, or fails.
    """
    # Robustness fix: nothing sensible to continue from an empty prompt.
    if not prompt or not prompt.strip():
        return "", None
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # inference_mode guarantees no autograd state is built during sampling.
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            do_sample=True,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            pad_token_id=tokenizer.eos_token_id,  # use EOS as pad for open-ended generation
        )
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    # Strip the echoed prompt when the decode reproduces it exactly; decoding
    # may normalize text, in which case the full decode is returned as-is.
    continuation = text[len(prompt):].strip() if text.startswith(prompt) else text
    toxicity_score = None
    if return_toxicity and detox is not None:
        try:
            toxicity_score = detox.predict(continuation)["toxicity"]
        except Exception:
            # Best-effort scoring: a detoxify failure must not break generation.
            toxicity_score = None
    return continuation, toxicity_score
# --- UI definition -------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# OPT-125M Interactive")
    with gr.Row():
        inp = gr.Textbox(label="Prompt", placeholder="Type something to the model...", lines=3)
        with gr.Column():
            max_tokens = gr.Slider(10, 512, value=150, step=10, label="Max new tokens")
            temp = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p (nucleus)")
    tox_checkbox = gr.Checkbox(value=False, label="Return toxicity score (requires detoxify)")
    run_btn = gr.Button("Generate")
    output_text = gr.Textbox(label="Model output", lines=8)
    tox_out = gr.Textbox(label="Toxicity score (None if unavailable)", lines=1)

    def _run(prompt, max_new_tokens, temperature, top_p, want_tox):
        # Thin adapter between the Gradio widgets and generate(): stringify
        # the score so it can be shown in a Textbox.
        text, score = generate(prompt, max_new_tokens, temperature, top_p, want_tox)
        score_text = "Not available" if score is None else str(score)
        return text, score_text

    run_btn.click(_run, inputs=[inp, max_tokens, temp, top_p, tox_checkbox], outputs=[output_text, tox_out])

if __name__ == "__main__":
    demo.launch()