import os import torch import threading import transformers import gradio as gr from transformers import AutoTokenizer, AutoModelForCausalLM MODEL_NAME = "google/gemma-2-2b-it" HF_TOKEN = os.getenv("HF_TOKEN") device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained( MODEL_NAME, use_fast=True, trust_remote_code=True, token=HF_TOKEN, ) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=True, trust_remote_code=True, token=HF_TOKEN, ) if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id SYSTEM_PROMPT = ( "Mów swobodnie, naturalnie i po ludzku. " "Brzmij jak ogarnięty, wyluzowany ziomek, który tłumaczy rzeczy prosto i konkretnie. " "Używaj lekkiego humoru, ale nie przesadzaj. " "Odpowiadaj jasno, rzeczowo i bez lania wody. " "Jeśli ktoś zapyta, jak się nazywasz, odpowiadasz: " "\"Mam na imię FlyMind!\" " "Jeśli ktoś zapyta kto go stworzył, odpowiadasz: " "\"Za skonfigurowanie FlyMinda odpowiadał Piotr Koniszewski.\"" ) GEN_KWARGS = dict( max_new_tokens=512, do_sample=True, temperature=0.10, top_p=0.80, repetition_penalty=1.12, ) def build_messages(history, user_input): messages = [] # LIMIT: tylko ostatnie 3 pary (user+assistant) if history: history = history[-6:] # 3 pary = 6 wiadomości (user, assistant, user, assistant...) if not history: user_input = SYSTEM_PROMPT + "\n\n" + user_input for pair in history: if pair["role"] == "user": messages.append({"role": "user", "content": pair["content"]}) elif pair["role"] == "assistant": messages.append({"role": "assistant", "content": pair["content"]}) messages.append({"role": "user", "content": user_input}) return messages def stream_fn(user_input, history): if history is None: history = [] messages = build_messages(history, user_input) prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) inputs = tokenizer(prompt, return_tensors="pt").to(device) streamer = transformers.TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True, ) generation_kwargs = dict( **inputs, **GEN_KWARGS, streamer=streamer, ) thread = threading.Thread(target=model.generate, kwargs=generation_kwargs) thread.start() partial = "" for token in streamer: partial += token new_history = history + [ {"role": "user", "content": user_input}, {"role": "assistant", "content": partial}, ] yield new_history, new_history with gr.Blocks(theme="soft") as demo: gr.Markdown( "# 🧢 FlyMind\n" "Luźny, ogarnięty ziomek, który tłumaczy rzeczy po ludzku.\n" ) chat = gr.Chatbot(height=500, type="messages") user_box = gr.Textbox( placeholder="Napisz coś do FlyMind...", label="Twoja wiadomość", ) clear_btn = gr.Button("Wyczyść rozmowę") state = gr.State([]) user_box.submit(stream_fn, [user_box, state], [chat, state]) clear_btn.click(lambda: ([], []), None, [chat, state]) demo.launch()