# FlyMind2 / app.py
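"""Gradio chat demo ("FlyMind") that streams replies from google/gemma-2-2b-it."""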
import os
import torch
import threading
import transformers
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
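
# Model configuration. HF_TOKEN is read from the environment (e.g. a Space secret);
# it is needed because the Gemma weights are gated on the Hugging Face Hub.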
MODEL_NAME = "google/gemma-2-2b-it"
HF_TOKEN = os.getenv("HF_TOKEN")
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    use_fast=True,
    trust_remote_code=True,
    token=HF_TOKEN,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    token=HF_TOKEN,
)
model.to(device)  # keep the weights on the same device the inputs are sent to in stream_fn
model.eval()
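# Fall back to the EOS token for padding if the tokenizer does not define one.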
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
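# Persona instructions for the bot; build_messages prepends this text to the first
# user message instead of using a separate system role.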
SYSTEM_PROMPT = (
    "Speak casually, naturally, and like a real person. "
    "Sound like a sharp, laid-back buddy who explains things simply and concretely. "
    "Use light humor, but don't overdo it. "
    "Answer clearly, to the point, and with no fluff. "
    "If someone asks what your name is, you answer: "
    "\"My name is FlyMind!\" "
    "If someone asks who created you, you answer: "
    "\"FlyMind was configured by Piotr Koniszewski.\""
)
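# Decoding settings: low temperature with nucleus sampling keeps replies focused,
# and a mild repetition penalty discourages the model from looping.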
GEN_KWARGS = dict(
    max_new_tokens=512,
    do_sample=True,
    temperature=0.10,
    top_p=0.80,
    repetition_penalty=1.12,
)
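
# Build the chat-template message list from the Gradio history, keeping only the
# last three exchanges and injecting SYSTEM_PROMPT into the very first user turn.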
def build_messages(history, user_input):
    messages = []
    # LIMIT: only the last 3 (user + assistant) pairs
    if history:
        history = history[-6:]  # 3 pairs = 6 messages (user, assistant, user, assistant, ...)
    if not history:
        user_input = SYSTEM_PROMPT + "\n\n" + user_input
    for pair in history:
        if pair["role"] == "user":
            messages.append({"role": "user", "content": pair["content"]})
        elif pair["role"] == "assistant":
            messages.append({"role": "assistant", "content": pair["content"]})
    messages.append({"role": "user", "content": user_input})
    return messages
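
# Generator used by Gradio: runs model.generate in a worker thread and yields the
# partial transcript after every streamed token.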
def stream_fn(user_input, history):
    if history is None:
        history = []
    messages = build_messages(history, user_input)
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
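    # TextIteratorStreamer lets generate() run in a background thread while this
    # generator reads the decoded tokens as they arrive.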
    streamer = transformers.TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs = dict(
        **inputs,
        **GEN_KWARGS,
        streamer=streamer,
    )
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
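    # Accumulate the reply token by token and yield the updated transcript each time,
    # so the chat window and the stored state stay in sync while the answer streams in.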
    partial = ""
    for token in streamer:
        partial += token
        new_history = history + [
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": partial},
        ]
        yield new_history, new_history
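
# Gradio UI: chat window, input box, and a clear button; `state` holds the message
# history that is passed back into stream_fn on every turn.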
with gr.Blocks(theme="soft") as demo:
    gr.Markdown(
        "# 🧢 FlyMind\n"
        "A laid-back, sharp buddy who explains things in plain, human terms.\n"
    )
    chat = gr.Chatbot(height=500, type="messages")
    user_box = gr.Textbox(
        placeholder="Write something to FlyMind...",
        label="Your message",
    )
    clear_btn = gr.Button("Clear conversation")
    state = gr.State([])
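    # Submitting the textbox streams a reply; the clear button resets both the chat
    # display and the stored history.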
    user_box.submit(stream_fn, [user_box, state], [chat, state])
    clear_btn.click(lambda: ([], []), None, [chat, state])
demo.launch()