# chat-3 / app.py
# metastable-void
# workaround
# 30790ee
#!/usr/bin/env python
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from peft import PeftModel
from transformers import (AutoModelForCausalLM, AutoTokenizer,
TextIteratorStreamer, pipeline)
# Markdown/HTML shown at the top of the chat UI; a warning is appended when
# no CUDA device is visible at import time.
_CPU_NOTICE = "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
DESCRIPTION = "# 真空ジェネレータ (v3)\n<p>Imitate 真空 (@vericava)'s posts interactively</p>" + (
    "" if torch.cuda.is_available() else _CPU_NOTICE
)

# Upper bound and initial value for the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 128
DEFAULT_MAX_NEW_TOKENS = 64
# Overridable through the environment; falls back to 32768.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "32768"))
# Build the text-generation pipeline only when a CUDA device is visible.
# On CPU-only hosts `my_pipeline` is never bound (the UI description warns
# that the demo does not work there).
if torch.cuda.is_available():
    my_pipeline = pipeline(
        "text-generation",
        model="vericava/llm-jp-3-vericava-posts-v1",
        # Passed to the pipeline constructor; presumably used as default
        # generation settings (sampling, no beam search) — verify in the
        # transformers version pinned for this Space.
        num_beams=1,
        do_sample=True,
    )
# Sentence-ending marks (half- and full-width) after which no "。" is appended.
_SENTENCE_ENDINGS = ("。", "?", "!", "？", "！")


@spaces.GPU
@torch.inference_mode()
def generate(
    message: str,
    chat_history,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Reply to ``message`` with one continuation generated by ``my_pipeline``.

    The message is flattened to a single line (line breaks become spaces) and
    "。" is appended unless it already ends with sentence-ending punctuation
    (half- or full-width). The result is used as a raw text prompt.

    Args:
        message: The user's latest chat input.
        chat_history: Prior conversation supplied by the chat UI. Unused; each
            reply is conditioned on ``message`` alone.
        max_new_tokens: Forwarded to the pipeline as ``max_new_tokens``.
        temperature: Forwarded to the pipeline as ``temperature``.
        top_p: Forwarded to the pipeline as ``top_p``.
        top_k: Forwarded to the pipeline as ``top_k``.
        repetition_penalty: Forwarded to the pipeline as ``repetition_penalty``.

    Yields:
        The generated text with the prompt prefix removed, as a single chunk.

    Raises:
        NameError: If ``my_pipeline`` was never created (no CUDA device was
            visible when the module was imported).
    """
    # splitlines() also handles "\r\n" and other line boundaries, so no stray
    # "\r" is left in the prompt.
    user_input = " ".join(message.strip().splitlines())
    if not user_input.endswith(_SENTENCE_ENDINGS):
        user_input += "。"

    results = my_pipeline(
        user_input,
        # Coerce the float-valued knobs in case they arrive as ints.
        temperature=float(temperature),
        max_new_tokens=max_new_tokens,
        repetition_penalty=float(repetition_penalty),
        top_k=top_k,
        top_p=float(top_p),
    )
    # The pipeline returns a list of result dicts; use the last one.
    output = results[-1]["generated_text"]
    # Echo the full prompt + completion to stdout.
    print(output)

    # generated_text is expected to begin with the prompt; strip it so only the
    # new text is shown. The completion is returned untruncated.
    yield output[len(user_input):]
# Collapsible panel ("詳細設定" = detailed settings) that holds the sampling controls.
_settings_accordion = gr.Accordion(label="詳細設定", open=False)

# Slider values are passed to generate() positionally after
# (message, chat_history), so keep this order in sync with its signature.
_sampling_controls = [
    gr.Slider(
        label="Max new tokens",
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
    ),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=1.0),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        value=0.90,
    ),
    gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=20),
    gr.Slider(
        label="Repetition penalty",
        minimum=1.0,
        maximum=4.0,
        step=0.05,
        value=2.0,
    ),
]

# One-click example messages; each example is a single-element row.
_example_prompts = ("おはよ", "えらいね", "にゃん", "よしよし")

demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    additional_inputs_accordion=_settings_accordion,
    additional_inputs=_sampling_controls,
    stop_btn=None,
    examples=[[prompt] for prompt in _example_prompts],
    description=DESCRIPTION,
    css_paths="style.css",
    fill_height=True,
)
def _main() -> None:
    """Serve the chat demo with Gradio's default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _main()