# app.py for the 0-roleplay Space: a Gradio roleplay chat demo built on
# the Rorical/0-roleplay-20240814-AWQ model.
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import spaces
import time
# Global variables for model and tokenizer
model = None
tokenizer = None
def load_model():
"""Load the model and tokenizer - must be GPU decorated for Spaces"""
global model, tokenizer
if model is None:
print("Loading model...")
        # Optional bitsandbytes 4-bit config; the checkpoint below appears to be
        # AWQ-quantized already (per its name), so this is normally unnecessary:
# nf4_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_use_double_quant=True,
# bnb_4bit_compute_dtype=torch.bfloat16
# )
# Load model - adjust model name and settings as needed
model = AutoModelForCausalLM.from_pretrained(
"Rorical/0-roleplay-20240814-AWQ",
device_map="auto", # Changed from "cuda" to "auto"
# attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
# quantization_config=nf4_config, # Uncomment if using quantization
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
"Rorical/0-roleplay-20240814-AWQ"
)
# Set custom chat template
tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + ((message['role'] + '\n') if message['role'] != '' else '') + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
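        # With this template, a hypothetical history such as
        #   [{"role": "user", "content": "hello"}, {"role": "ηžη’ƒ", "content": "hi!"}]
        # renders as:
        #   <|im_start|>user\nhello<|im_end|>\n<|im_start|>ηžη’ƒ\nhi!<|im_end|>\n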
print("Model loaded successfully!")
def chat_response(message, history, character_name="ηžη’ƒ", max_tokens=512, temperature=0.7, top_p=0.9):
"""Generate chat response using the loaded model"""
global model, tokenizer
# Load model if not already loaded
if model is None or tokenizer is None:
load_model()
# Convert history to the expected format for the model
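    # Assistant turns are relabeled with the character's name so the chat
    # template tags them as that speaker rather than "assistant".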
messages = []
for msg in history:
if msg['role'] == 'user':
messages.append({"role": "user", "content": msg['content']})
elif msg['role'] == 'assistant':
messages.append({"role": character_name, "content": msg['content']})
# Add current message
messages.append({"role": "user", "content": message})
try:
# Format chat using tokenizer template
formatted_chat = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False
)
formatted_chat += f"<|im_start|>{character_name}\n"
# Tokenize input
inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False)
# Move to the same device as model
inputs = inputs.to(model.device)
# Generate response
with torch.no_grad():
start_time = time.time()
generate_ids = model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
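                # many causal LMs define no pad token, so EOS is reused for padding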
pad_token_id=tokenizer.eos_token_id,
do_sample=True,
use_cache=True,
)
end_time = time.time()
# Decode response
response_text = tokenizer.decode(
generate_ids[0, inputs['input_ids'].size(1):],
skip_special_tokens=True
).strip()
print(f"Generation time: {end_time - start_time:.2f}s")
return response_text
except Exception as e:
print(f"Error during generation: {e}")
return f"Sorry, I encountered an error: {str(e)}"
def chat_response_streaming(message, history, character_name="ηžη’ƒ", max_tokens=512, temperature=0.7, top_p=0.9):
"""Generate streaming chat response"""
global model, tokenizer
# Load model if not already loaded
if model is None or tokenizer is None:
load_model()
# Convert history to the expected format for the model
messages = []
for msg in history:
if msg['role'] == 'user':
messages.append({"role": "user", "content": msg['content']})
elif msg['role'] == 'assistant':
messages.append({"role": character_name, "content": msg['content']})
# Add current message
messages.append({"role": "user", "content": message})
try:
# Format chat using tokenizer template
formatted_chat = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False
)
formatted_chat += f"<|im_start|>{character_name}\n"
# Tokenize input
inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False)
# Move to the same device as model
inputs = inputs.to(model.device)
# Set up streaming
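        # skip_prompt=True drops the echoed input tokens, so only newly
        # generated text is pushed through the streamer.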
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
pad_token_id=tokenizer.eos_token_id,
do_sample=True,
use_cache=True,
)
# Start generation in a separate thread
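        # model.generate blocks until done, so it runs on a worker thread while
        # this thread drains the streamer queue and yields partial text.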
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# Stream the response
partial_response = ""
for new_text in streamer:
if new_text:
partial_response += new_text
yield partial_response
thread.join()
except Exception as e:
print(f"Error during streaming generation: {e}")
yield f"Sorry, I encountered an error: {str(e)}"
def create_interface():
"""Create the Gradio interface"""
# Custom CSS for better styling
css = """
.gradio-container {
max-width: 1200px !important;
margin: auto !important;
}
.chat-message {
font-size: 16px !important;
}
"""
with gr.Blocks(css=css, title="AI Roleplay Chatbot", theme=gr.themes.Soft()) as demo:
gr.Markdown("# πŸ€– AI Roleplay Chatbot")
gr.Markdown("Chat with an AI character using advanced language modeling.")
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(
label="Chat",
height=500,
show_label=False,
container=True,
elem_classes=["chat-message"],
type='messages'
)
with gr.Row():
msg = gr.Textbox(
label="Message",
placeholder="Type your message here...",
lines=2,
max_lines=4,
show_label=False,
container=False
)
send_btn = gr.Button("Send", variant="primary", size="lg")
with gr.Row():
clear_btn = gr.Button("Clear Chat", variant="secondary")
retry_btn = gr.Button("Retry Last", variant="secondary")
with gr.Column(scale=1):
gr.Markdown("### Settings")
character_name = gr.Textbox(
label="Character Name",
value="ηžη’ƒ",
placeholder="Enter character name..."
)
max_tokens = gr.Slider(
label="Max Tokens",
minimum=50,
maximum=1024,
value=512,
step=50
)
temperature = gr.Slider(
label="Temperature",
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1
)
top_p = gr.Slider(
label="Top P",
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.05
)
streaming = gr.Checkbox(
label="Enable Streaming",
value=True
)
# Event handlers
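        # On ZeroGPU Spaces, @spaces.GPU requests a GPU for each call, with
        # `duration` as the expected upper bound in seconds.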
@spaces.GPU(duration=30)
def respond(message, chat_history, char_name, max_tok, temp, top_p_val, use_streaming):
if use_streaming:
# For streaming response - add user message first
chat_history.append({"role": "user", "content": message})
chat_history.append({"role": "assistant", "content": ""})
for partial_response in chat_response_streaming(
message, chat_history[:-2], char_name, max_tok, temp, top_p_val
):
chat_history[-1]["content"] = partial_response
yield chat_history, ""
else:
# For non-streaming response
response = chat_response(message, chat_history, char_name, max_tok, temp, top_p_val)
chat_history.append({"role": "user", "content": message})
chat_history.append({"role": "assistant", "content": response})
yield chat_history, ""
def clear_chat():
return [], ""
        @spaces.GPU(duration=30)
        def retry_last(chat_history, char_name, max_tok, temp, top_p_val, use_streaming):
            if not chat_history:
                # retry_last is a generator, so a bare `return value` would be
                # discarded; yield the unchanged state instead.
                yield chat_history, ""
                return
# Find the last user message
last_user_message = None
for i in range(len(chat_history) - 1, -1, -1):
if chat_history[i]['role'] == 'user':
last_user_message = chat_history[i]['content']
# Remove the last user message and any assistant responses after it
chat_history = chat_history[:i]
break
            if last_user_message is None:
                yield chat_history, ""
                return
if use_streaming:
chat_history.append({"role": "user", "content": last_user_message})
chat_history.append({"role": "assistant", "content": ""})
for partial_response in chat_response_streaming(
last_user_message, chat_history[:-2], char_name, max_tok, temp, top_p_val
):
chat_history[-1]["content"] = partial_response
yield chat_history, ""
else:
response = chat_response(last_user_message, chat_history, char_name, max_tok, temp, top_p_val)
chat_history.append({"role": "user", "content": last_user_message})
chat_history.append({"role": "assistant", "content": response})
yield chat_history, ""
# Connect events
msg.submit(
respond,
inputs=[msg, chatbot, character_name, max_tokens, temperature, top_p, streaming],
outputs=[chatbot, msg]
)
send_btn.click(
respond,
inputs=[msg, chatbot, character_name, max_tokens, temperature, top_p, streaming],
outputs=[chatbot, msg]
)
clear_btn.click(clear_chat, outputs=[chatbot, msg])
retry_btn.click(
retry_last,
inputs=[chatbot, character_name, max_tokens, temperature, top_p, streaming],
outputs=[chatbot, msg]
)
return demo
demo = create_interface()

if __name__ == "__main__":
    demo.launch(show_error=True)