# generativetext / appmin.py
# Hugging Face Space file: last commit "Update appmin.py" (fbe68d2, verified)
# by smartdigitalnetworks; file size 5.1 kB.
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face Hub id of the instruction-tuned model served by this app.
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"

# Description text handed to the ChatInterface.
DESCRIPTION = "# Generative Text Chat"

# Range ceiling and initial value of the "Max new tokens" slider.
MAX_NEW_TOKENS_LIMIT = 2048
DEFAULT_MAX_NEW_TOKENS = 1024

# Token limit enforced by `generate`; override via the MAX_INPUT_TOKENS env var.
MAX_INPUT_TOKENS = int(os.environ.get("MAX_INPUT_TOKENS", "4096"))
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load the weights in bfloat16 with automatic device placement, then switch the
# module to inference mode. nn.Module.eval() returns the module itself, so the
# call can be chained onto the load.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.bfloat16,
).eval()
@spaces.GPU
def _generate_on_gpu(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
) -> Iterator[str]:
    """Sample a completion on the model's device and stream it as it grows.

    ``model.generate`` runs in a background thread and pushes decoded text into a
    ``TextIteratorStreamer``. This generator drains that streamer and yields the
    cumulative text produced so far after every chunk.

    Args:
        input_ids: Prompt token ids (moved to ``model.device`` here).
        attention_mask: Attention mask for ``input_ids`` (moved to ``model.device`` here).
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Yields:
        The full reply text generated so far (not only the newest chunk).

    Raises:
        gr.Error: If ``model.generate`` raised in the worker thread.
    """
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    # skip_prompt: stream only newly generated text, not the echoed prompt.
    # timeout: the consuming loop below gives up if no text arrives for 20 s.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
        "disable_compile": True,
    }

    # The worker thread cannot raise into this generator directly, so it parks
    # any exception here for re-raising after the stream ends.
    exception_holder: list[Exception] = []

    def _generate() -> None:
        try:
            model.generate(**generate_kwargs)
        except Exception as e:  # noqa: BLE001 - re-raised as gr.Error below
            exception_holder.append(e)
            # A failed generate() may never signal end-of-stream. Without this,
            # the `for text in streamer` loop would block until the 20 s timeout
            # and then fail with the streamer's timeout error instead of the
            # gr.Error below. Ending the stream lets the real error surface now.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    chunks: list[str] = []
    for text in streamer:
        chunks.append(text)
        yield "".join(chunks)

    thread.join()
    if exception_holder:
        msg = f"Generation failed: {exception_holder[0]}"
        raise gr.Error(msg) from exception_holder[0]
def validate_input(message: str) -> dict:
    """Validator hook for the ChatInterface: reject empty or whitespace-only input.

    Returns:
        The ``gr.validate`` result. It is valid only when ``message`` has at least
        one non-whitespace character; otherwise it carries the text
        "Please enter a message.".
    """
    stripped = (message or "").strip()
    return gr.validate(len(stripped) > 0, "Please enter a message.")
def _history_to_conversation(chat_history: list[dict]) -> list[dict]:
    """Convert Gradio chat history into plain-text ``{"role", "content"}`` messages.

    When an entry's ``content`` is a list of parts, only the parts whose
    ``"type"`` is ``"text"`` are kept, concatenated in order. Any other content
    is passed through unchanged.
    """
    conversation = []
    for entry in chat_history:
        content = entry["content"]
        if isinstance(content, list):
            text = "".join(part["text"] for part in content if part.get("type") == "text")
        else:
            text = content
        conversation.append({"role": entry["role"], "content": text})
    return conversation


def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Chat handler: build the prompt from history plus ``message`` and stream the reply.

    Args:
        message: The new user message.
        chat_history: Prior turns as dicts with ``"role"`` and ``"content"`` keys.
        max_new_tokens: Requested completion length. It is clamped to the token
            budget left under ``MAX_INPUT_TOKENS`` after the prompt.
        temperature: Sampling temperature, forwarded to ``_generate_on_gpu``.
        top_p: Nucleus-sampling mass, forwarded to ``_generate_on_gpu``.
        top_k: Top-k cutoff, forwarded to ``_generate_on_gpu``.
        repetition_penalty: Repetition penalty, forwarded to ``_generate_on_gpu``.

    Yields:
        The cumulative reply text as it is generated.

    Raises:
        gr.Error: If the prompt exceeds ``MAX_INPUT_TOKENS`` tokens, leaves no
            room for new tokens, or generation fails.
    """
    # UI slider values are forwarded unchanged. Coerce the integer-valued ones
    # so a value such as 50.0 cannot trip transformers' integer check on top_k.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    conversation = _history_to_conversation(chat_history)
    conversation.append({"role": "user", "content": message})

    inputs = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt", return_dict=True
    )
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    n_input_tokens = input_ids.shape[1]
    if n_input_tokens > MAX_INPUT_TOKENS:
        err_msg = f"Input too long ({n_input_tokens} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
        raise gr.Error(err_msg)

    # MAX_INPUT_TOKENS also caps prompt + completion: the completion may only use
    # what remains after the prompt.
    max_new_tokens = min(max_new_tokens, MAX_INPUT_TOKENS - n_input_tokens)
    if max_new_tokens <= 0:
        raise gr.Error("Input uses the entire context window. No room to generate new tokens.")

    yield from _generate_on_gpu(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
# Extra controls shown with the chat. They are listed in the same order as
# generate's parameters after (message, chat_history): max_new_tokens,
# temperature, top_p, top_k, repetition_penalty.
_sampling_controls = [
    gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_NEW_TOKENS_LIMIT, step=1, value=DEFAULT_MAX_NEW_TOKENS),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
    gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
    gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
    gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
]

# Clickable starter prompts; each example is a one-element list holding the message.
_example_prompts = [
    "Hello there! How are you doing?",
    "Can you explain briefly to me what is the Python programming language?",
    "Explain the plot of Cinderella in a sentence.",
    "How many hours does it take a man to eat a Helicopter?",
    "Write a 100-word article on 'Benefits of Open-Source in AI research'",
]

demo = gr.ChatInterface(
    fn=generate,
    validator=validate_input,
    additional_inputs=_sampling_controls,
    stop_btn=False,
    examples=[[prompt] for prompt in _example_prompts],
    cache_examples=False,
    description=DESCRIPTION,
    fill_height=True,
)
if __name__ == "__main__":
    # Start the app with the custom stylesheet applied.
    stylesheet = "style.css"
    demo.launch(css_paths=stylesheet)