File size: 3,177 Bytes
f0dff07
 
 
9ab7a40
f0dff07
 
 
 
 
59be267
30790ee
 
f0dff07
45c9a85
f0dff07
 
 
 
45c9a85
 
f0dff07
 
245e479
f0dff07
97befb1
59be267
45c9a85
5a02dd0
 
97befb1
f0dff07
 
 
 
 
4a30925
f0dff07
 
 
 
 
 
e38ab6b
 
 
 
 
 
 
 
 
f0dff07
97befb1
e38ab6b
30790ee
0e35e11
30790ee
e38ab6b
30790ee
4a30925
9d7e24a
e38ab6b
45c9a85
 
e38ab6b
f0dff07
920b6db
f0dff07
4a30925
920b6db
f0dff07
 
 
 
 
 
 
 
 
 
 
 
 
0e35e11
f0dff07
 
 
 
 
 
45c9a85
f0dff07
 
 
 
45c9a85
f0dff07
45c9a85
f0dff07
 
 
 
45c9a85
f0dff07
45c9a85
f0dff07
 
 
 
45c9a85
 
97befb1
45c9a85
f0dff07
920b6db
 
 
f0dff07
 
 
920b6db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from peft import PeftModel
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          TextIteratorStreamer, pipeline)

DESCRIPTION = (
    "# 真空ジェネレータ (v3)\n"
    "<p>Imitate 真空 (@vericava)'s posts interactively</p>"
)

# Generation-length limits (used as the "Max new tokens" slider bounds/default)
# and an input-length cap that can be overridden via the environment.
MAX_MAX_NEW_TOKENS = 128
DEFAULT_MAX_NEW_TOKENS = 64
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))


# Load the model only when a GPU is visible; otherwise no pipeline is created
# and the page description carries a CPU warning instead.
if torch.cuda.is_available():
    my_pipeline = pipeline(
        task="text-generation",
        model="vericava/llm-jp-3-vericava-posts-v1",
        do_sample=True,
        num_beams=1,
    )
else:
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

@spaces.GPU
@torch.inference_mode()
def generate(
    message: str,
    chat_history,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Generate a post continuing ``message`` with the module-level pipeline.

    The message is flattened to a single line and, unless it already ends in
    a sentence-final mark (ASCII or full-width), terminated with "。" before
    being fed to ``my_pipeline`` as a plain-text prompt.

    NOTE: ``my_pipeline`` is only created when CUDA is available at import
    time; without it this function raises ``NameError``.

    Args:
        message: The user's latest chat message.
        chat_history: History supplied by ``gr.ChatInterface``. Unused: each
            reply is generated from ``message`` alone.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature, forwarded to the pipeline.
        top_p: Nucleus-sampling threshold, forwarded to the pipeline.
        top_k: Top-k sampling cutoff, forwarded to the pipeline.
        repetition_penalty: Repetition penalty, forwarded to the pipeline.

    Yields:
        The generated text without the prompt, exactly once.
    """
    # Collapse a multi-line message into a single line.
    user_input = " ".join(message.strip().split("\n"))

    # Accept both ASCII and full-width (Japanese) sentence-final marks; the
    # previous check listed "?" and "!" twice and missed "？"/"！".
    sentence_endings = ("。", "?", "!", "？", "！")
    if not user_input.endswith(sentence_endings):
        user_input += "。"

    outputs = my_pipeline(
        user_input,
        temperature=float(temperature),
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty),
        top_k=int(top_k),
        top_p=float(top_p),
        # Let the pipeline strip the prompt itself (using its own decoded
        # prompt length) instead of slicing by len(user_input), which is wrong
        # whenever detokenization does not reproduce the prompt verbatim.
        return_full_text=False,
    )
    gen_text = outputs[-1]["generated_text"]
    print(gen_text)
    yield gen_text

# Sampling controls shown in the collapsed "詳細設定" (advanced settings)
# accordion. ChatInterface passes their values to ``generate`` after
# (message, chat_history), so this order must match its parameters:
# max_new_tokens, temperature, top_p, top_k, repetition_penalty.
_SLIDER_SPECS = (
    # (label, minimum, maximum, step, value)
    ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
    ("Temperature", 0.1, 4.0, 0.1, 1.0),
    ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.90),
    ("Top-k", 1, 100, 1, 20),
    ("Repetition penalty", 1.0, 4.0, 0.05, 2.0),
)

# Created before the sliders to keep the original component creation order.
_settings_accordion = gr.Accordion(label="詳細設定", open=False)
_sampling_sliders = [
    gr.Slider(label=label, minimum=lo, maximum=hi, step=step, value=initial)
    for label, lo, hi, step, initial in _SLIDER_SPECS
]

demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    additional_inputs_accordion=_settings_accordion,
    additional_inputs=_sampling_sliders,
    stop_btn=None,
    examples=[["おはよ"], ["えらいね"], ["にゃん"], ["よしよし"]],
    description=DESCRIPTION,
    css_paths="style.css",
    fill_height=True,
)

if __name__ == "__main__":
    demo.launch()