import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
# Display name (shown in the UI dropdown) -> Hugging Face Hub repo id.
# NOTE(review): the v0.3 entry points at an MLX 4-bit repo — confirm it is
# actually loadable via AutoModelForCausalLM, which expects HF-format weights.
MODEL_NAMES = {
"dqnCode v0.2 1.5B": "DQN-Labs/dqnCode-v0.2-1.5B-HF",
"dqnCode v0.3 1.2B": "DQN-Labs/dqnCode-v0.3-1.2B-MLX-4bit",
}
# Process-wide cache: model_key -> (tokenizer, model), so each checkpoint
# is downloaded/loaded at most once per server lifetime.
model_cache = {}
def load_model(model_key):
    """Return a cached ``(tokenizer, model)`` pair for *model_key*.

    On a cache miss, downloads the checkpoint named in ``MODEL_NAMES``,
    places the model on CUDA (fp16) when available or CPU (fp32)
    otherwise, and memoizes the pair in ``model_cache``.
    """
    cached = model_cache.get(model_key)
    if cached is not None:
        return cached
    model_name = MODEL_NAMES[model_key]
    print(f"Loading {model_name}...")
    tok = AutoTokenizer.from_pretrained(model_name)
    use_cuda = torch.cuda.is_available()
    target_device = "cuda" if use_cuda else "cpu"
    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        # Half precision only on GPU; CPU inference stays in fp32.
        dtype=torch.float16 if use_cuda else torch.float32,
    ).to(target_device)
    model_cache[model_key] = (tok, lm)
    return tok, lm
def chat_with_model(message, history, model_choice):
    """Stream an assistant reply for the current conversation.

    Args:
        message: The new user message. May arrive empty because the UI's
            submit handler clears the textbox and appends the user turn to
            ``history`` before this generator runs; in that case the message
            is recovered from the last history entry.
        history: Chat transcript in Gradio "messages" format — a list of
            ``{"role": ..., "content": ...}`` dicts.
        model_choice: Key into ``MODEL_NAMES`` selecting the model.

    Yields:
        The full transcript (history + user turn + partial assistant turn)
        after each streamed token, so the Chatbot re-renders incrementally.
    """
    tokenizer, model = load_model(model_choice)
    device = model.device
    # BUGFIX: the submit wiring runs add_user_message first, which blanks the
    # textbox and appends the user turn to history — so `message` used to be
    # "" here and the user turn was duplicated. Recover it from history.
    if not message and history and history[-1].get("role") == "user":
        message = history[-1]["content"]
        history = history[:-1]
    # BUGFIX: trailing "\n" so the first turn is not glued onto the prompt.
    prompt = (
        "You are dqnCode, an intelligent and conversational AI assistant "
        "designed to help users with questions, problem-solving, and creative "
        "tasks. You communicate clearly, reason carefully, and explain your "
        "thoughts in an easy-to-understand way. Stay friendly, professional, "
        "and curious. If the user's request is ambiguous, ask clarifying "
        "questions before proceeding.\n"
    )
    for turn in history:
        prompt += f"{turn['role'].capitalize()}: {turn['content']}\n"
    prompt += f"User: {message}\nAssistant:"
    # skip_prompt avoids echoing the prompt; generation runs in a worker
    # thread while this generator drains the streamer.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=2048,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": partial_text},
        ]
    # The streamer is exhausted only after generate() returns, so this join
    # is a cheap cleanup that prevents a dangling thread reference.
    thread.join()
def create_demo():
    """Assemble and return the Gradio Blocks UI for the DQN Labs chat app."""
    with gr.Blocks(title="DQN Labs Chat") as demo:
        gr.Markdown("## DQN Labs Chat")

        model_picker = gr.Dropdown(
            label="Select Model",
            choices=list(MODEL_NAMES.keys()),
            value="dqnCode v0.2 1.5B",
        )
        chat_window = gr.Chatbot(
            label="Chat with DQN!",
            type="messages",
            height=450,
        )
        user_box = gr.Textbox(label="Your message", placeholder="Type away...")
        clear_btn = gr.Button("Clear")

        def _append_user_turn(text, transcript):
            # Echo the user's turn into the transcript and blank the textbox.
            return "", transcript + [{"role": "user", "content": text}]

        # Two-step submit: show the user turn immediately, then stream the
        # model's reply into the same Chatbot component.
        user_box.submit(
            _append_user_turn,
            [user_box, chat_window],
            [user_box, chat_window],
            queue=False,
        ).then(
            chat_with_model, [user_box, chat_window, model_picker], chat_window
        )
        clear_btn.click(lambda: [], None, chat_window, queue=False)
    return demo
if __name__ == "__main__":
    # Build the UI, enable request queuing (needed for streaming generators),
    # and serve on all interfaces at the conventional Gradio port.
    app = create_demo()
    app.queue()
    app.launch(server_name="0.0.0.0", server_port=7860)