#!/usr/bin/env python
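"""Streaming Gradio chat demo for cyberagent/calm2-7b-chat."""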
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, TextIteratorStreamer
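

# Cooperative stop flag: the streaming loop sets `stopped` when the client
# disconnects, and `model.generate` halts at the next decoding step.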
class StopOnSignal(StoppingCriteria):
    def __init__(self) -> None:
        self.stopped = False

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs: object) -> bool:  # noqa: ARG002
        return self.stopped
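

# UI description and generation limits; MAX_INPUT_TOKEN_LENGTH can be
# overridden via the environment variable of the same name.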
DESCRIPTION = "# CALM2-7B-chat"
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))
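
# Load the model and tokenizer once at import time. `device_map="auto"` places
# the weights on the available device and `torch_dtype="auto"` keeps the
# checkpoint's native dtype.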
model_id = "cyberagent/calm2-7b-chat"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
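

# CALM2-7B-chat is prompted with plain "USER: ..." / "ASSISTANT: ..." turns
# (per the model card) rather than a tokenizer chat template.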
def apply_chat_template(conversation: list[dict[str, str]]) -> str:
prompt = "\n".join([f"{c['role']}: {c['content']}" for c in conversation])
return f"{prompt}\nASSISTANT: "
@spaces.GPU
def _run_on_gpu(
input_ids: torch.Tensor,
max_new_tokens: int,
temperature: float,
top_p: float,
top_k: int,
repetition_penalty: float,
) -> Iterator[str]:
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
stop_criteria = StopOnSignal()
generate_kwargs = {
"input_ids": input_ids,
"streamer": streamer,
"stopping_criteria": [stop_criteria],
"max_new_tokens": max_new_tokens,
"do_sample": True,
"top_p": top_p,
"top_k": top_k,
"temperature": temperature,
"num_beams": 1,
"repetition_penalty": repetition_penalty,
}
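
    # Run generation on a worker thread so tokens can be streamed as they
    # arrive; any exception is captured and re-raised on the main thread.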
exception_holder: list[Exception] = []
def _generate() -> None:
try:
model.generate(**generate_kwargs)
except Exception as e: # noqa: BLE001
exception_holder.append(e)
t = Thread(target=_generate)
t.start()
outputs: list[str] = []
try:
for text in streamer:
outputs.append(text)
yield "".join(outputs)
except GeneratorExit:
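        # The client disconnected: tell generate() to stop, drain the streamer
        # so the worker thread can exit, then re-raise to close the generator.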
stop_criteria.stopped = True
for _ in streamer:
pass
t.join()
raise
t.join()
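    # Surface worker-thread failures in the Gradio UI.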
if exception_holder:
err_msg = f"Generation failed: {exception_holder[0]}"
raise gr.Error(err_msg)
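

# ChatInterface handler. `chat_history` arrives in Gradio's "messages" format
# (dicts with "role" and "content"); it is flattened into the USER/ASSISTANT
# prompt the model expects.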
def generate(
message: str,
chat_history: list[dict[str, str]],
max_new_tokens: int = 1024,
temperature: float = 0.7,
top_p: float = 0.95,
top_k: int = 50,
repetition_penalty: float = 1.0,
) -> Iterator[str]:
conversation = []
for msg in chat_history:
role = "USER" if msg["role"] == "user" else "ASSISTANT"
if isinstance(msg["content"], list):
text = "".join(part["text"] for part in msg["content"] if part["type"] == "text")
else:
text = str(msg["content"])
conversation.append({"role": role, "content": text})
conversation.append({"role": "USER", "content": message})
prompt = apply_chat_template(conversation)
input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
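    # Keep only the most recent tokens when the conversation exceeds the limit.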
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
yield from _run_on_gpu(input_ids, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
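

# Chat UI: sampling controls live in a collapsible accordion below the chat.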
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    additional_inputs_accordion=gr.Accordion(label="詳細設定", open=False),  # "Advanced settings"
additional_inputs=[
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.7,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.95,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.0,
),
],
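    # Japanese example prompts (roughly: "Tell me Tokyo's tourist attractions.",
    # "What is an ochimusha (fleeing samurai)?", "Who is the Abarenbo Shogun?",
    # "How long would it take a person to eat a helicopter?").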
examples=[
["東京の観光名所を教えて。"],
["落武者って何?"], # noqa: RUF001
["暴れん坊将軍って誰のこと?"], # noqa: RUF001
["人がヘリを食べるのにかかる時間は?"], # noqa: RUF001
],
    description=DESCRIPTION,
    css_paths="style.css",
    fill_height=True,
)


if __name__ == "__main__":
    demo.launch()