# SpicyTestChat / app.py
# (Hugging Face Space page header, kept for provenance:
#  Erik's picture — Update app.py — commit 04d6ffd verified)
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import packaging.version
import transformers
import gradio as gr
# 4-bit NF4 quantization config: weights stored in 4 bits, matmuls computed in
# bfloat16, double quantization to also compress the quantization constants.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Hub repo id of the trained LoRA adapter.
username = 'Erik'
output_dir = 'nemo-sft-lora-deepspeed'  # alternative run: gromenauer-256-sft-lora-deepspeed
peft_model_id = f"{username}/{output_dir}"  # replace with your newly trained adapter
device = "cuda:0"  # NOTE(review): defined but unused below — device_map pins the model to cuda:0 directly

# The tokenizer lives in the adapter repo; the base model name is read from the
# adapter's PEFT config so base and adapter are guaranteed to match.
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,
                                             device_map={"": "cuda:0"},
                                             quantization_config=bnb_config)  # offload_state_dict=False
# transformers >= 4.46 added the `mean_resizing` argument to
# resize_token_embeddings; presumably it must be disabled for the quantized /
# FSDP-trained setup mirrored here — TODO confirm against the training script.
uses_transformers_4_46 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.46.0")
uses_fsdp = True  # assumed to mirror the training configuration — verify
if (bnb_config is not None) and uses_fsdp and uses_transformers_4_46:
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8, mean_resizing=False)
else:
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

# Attach the LoRA adapter on top of the (already resized) quantized base model.
model = PeftModel.from_pretrained(model, peft_model_id)
model.config.use_cache = True  # enable the KV cache for faster generation
model.to(torch.bfloat16)       # cast remaining non-quantized params to bf16
model.eval()
def format_history(msg: str, history: list[tuple[str, str]], system_prompt: str) -> list[dict[str, str]]:
    """Convert Gradio chat history into an OpenAI-style message list.

    Builds the list consumed by ``tokenizer.apply_chat_template``: a system
    message, the prior (user, assistant) turns in order, then the new user
    message.

    Args:
        msg: The latest user message (not yet in ``history``).
        history: Prior turns as (user_query, assistant_response) pairs.
        system_prompt: Text for the leading ``system`` message.

    Returns:
        Messages as ``{"role": ..., "content": ...}`` dicts.
    """
    # NOTE: original annotation was list[list[str, str]], which is not a valid
    # generic form (list[...] takes a single type argument).
    chat_history = [{"role": "system", "content": system_prompt}]
    for query, response in history:
        chat_history.append({"role": "user", "content": query})
        chat_history.append({"role": "assistant", "content": response})
    chat_history.append({"role": "user", "content": msg})
    return chat_history
def generate_response(msg: str, history: list[tuple[str, str]], system_prompt: str,
                      top_k: int, top_p: float, rep_pen: float, temperature: float):
    """Generate one assistant reply for the Gradio ChatInterface.

    Formats the conversation with the model's chat template, samples a
    completion, and yields only the newly generated text (prompt stripped).

    Args:
        msg: Latest user message.
        history: Prior (user, assistant) turns from Gradio.
        system_prompt: System message prepended to the conversation.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus sampling cutoff.
        rep_pen: Repetition penalty passed to ``generate``.
        temperature: Sampling temperature.

    Yields:
        The decoded assistant response (single yield; ChatInterface accepts
        generator functions).
    """
    chat_history = format_history(msg, history, system_prompt)
    # apply_chat_template returns the full prompt as a tensor of token ids.
    encodeds = tokenizer.apply_chat_template(chat_history, return_tensors="pt", add_generation_prompt=True)
    model_inputs = encodeds.to("cuda")
    # inference_mode avoids building autograd state during generation.
    with torch.inference_mode():
        generated_ids = model.generate(
            model_inputs,
            repetition_penalty=rep_pen,
            max_new_tokens=124,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            eos_token_id=tokenizer.eos_token_id,
            # Explicit pad_token_id silences the "Setting pad_token_id to
            # eos_token_id" warning; with batch size 1 no padding occurs.
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + completion; slice off the prompt tokens so
    # only the assistant's new text is decoded.
    input_length = encodeds.shape[1]
    new_tokens = generated_ids[0][input_length:]
    assistant_response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    yield assistant_response
# --- Gradio UI -------------------------------------------------------------
# Chat window with custom avatars for the user and the assistant persona.
chat_window = gr.Chatbot(
    avatar_images=["user_1.JPG", "greta_1.PNG"],
    height="64vh",
)

# Extra controls shown below the chat: system prompt plus sampling knobs,
# passed to generate_response in this order after (msg, history).
extra_controls = [
    gr.Textbox("Eres una profesora de Dinamarca caliente, juguetona y sexual. Te encanta el sexo anal y tocarte. Tu objetivo es acompañar en la sesión de sexting a tu compañero. Harás lo que sea para hacerle morir de placer, deseas que se masturbe.",
               label="System Prompt"),
    gr.Slider(0.0, 100.0, label="top_k", value=20,
              info="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
    gr.Slider(0.0, 1.0, label="top_p", value=0.6,
              info=" Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)"),
    gr.Slider(0.1, 2.0, label="rep_pen", value=1.2,
              info="Repetition penaly to avoid repetitions"),
    gr.Slider(0.0, 2.0, label="temperature", value=0.5,
              info="The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)"),
]

spicy_chatbot = gr.ChatInterface(
    generate_response,
    chatbot=chat_window,
    additional_inputs=extra_controls,
    title="Jugamos un rato?",
    theme="allenai/gradio-theme",
    submit_btn="⬅ Send",
    css="footer {visibility: hidden}",
)

# queue() enables request queuing (needed for generator responses);
# share=True exposes a public Gradio link.
spicy_chatbot.queue().launch(share=True)