"""Gradio chat UI for a 4-bit (NF4) quantized causal LM with a LoRA/PEFT adapter.

Loads the base model referenced by the adapter config, resizes embeddings to
match the tokenizer, attaches the PEFT adapter, and serves a ChatInterface.
"""

from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import packaging.version
import transformers
import gradio as gr

# QLoRA-style inference quantization: 4-bit NF4 weights, bf16 compute,
# double quantization to shave a little more memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

username = 'Erik'
output_dir = 'nemo-sft-lora-deepspeed'  # gromenauer-256-sft-lora-deepspeed
peft_model_id = f"{username}/{output_dir}"  # replace with your newly trained adapter
device = "cuda:0"

tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    device_map={"": device},  # single source of truth for the target device
    quantization_config=bnb_config,
)

# transformers >= 4.46 introduced `mean_resizing` in resize_token_embeddings;
# it must be disabled for quantized weights, hence the version gate below.
uses_transformers_4_46 = (
    packaging.version.parse(transformers.__version__)
    >= packaging.version.parse("4.46.0")
)
uses_fsdp = True  # NOTE(review): hard-coded flag — confirm it matches the training setup
if (bnb_config is not None) and uses_fsdp and uses_transformers_4_46:
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8, mean_resizing=False)
else:
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

# Attach the LoRA adapter on top of the quantized base model.
model = PeftModel.from_pretrained(model, peft_model_id)
model.config.use_cache = True  # enable KV cache for faster generation
model.to(torch.bfloat16)
model.eval()


def format_history(msg: str, history: list[list[str]], system_prompt: str) -> list[dict]:
    """Convert Gradio's [user, assistant] pair history into chat-template messages.

    Args:
        msg: The new user message.
        history: Prior turns as [query, response] pairs (Gradio tuple format).
        system_prompt: System message placed first in the conversation.

    Returns:
        A list of {"role": ..., "content": ...} dicts ending with the new user turn.
    """
    chat_history = [{"role": "system", "content": system_prompt}]
    for query, response in history:
        chat_history.append({"role": "user", "content": query})
        chat_history.append({"role": "assistant", "content": response})
    chat_history.append({"role": "user", "content": msg})
    return chat_history


def generate_response(msg: str, history: list[list[str]], system_prompt: str,
                      top_k: int, top_p: float, rep_pen: float, temperature: float):
    """Generate one assistant reply for the ChatInterface.

    Yields the decoded reply a single time (no token-level streaming).

    Args:
        msg: The new user message.
        history: Prior [query, response] pairs from Gradio.
        system_prompt: System message for the conversation.
        top_k / top_p / rep_pen / temperature: Sampling controls bound to the
            UI sliders below.
    """
    chat_history = format_history(msg, history, system_prompt)

    # Apply the model's chat template to obtain the full tokenized prompt.
    encodeds = tokenizer.apply_chat_template(
        chat_history, return_tensors="pt", add_generation_prompt=True
    )
    model_inputs = encodeds.to(device)

    generated_ids = model.generate(
        model_inputs,
        repetition_penalty=rep_pen,
        max_new_tokens=124,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        eos_token_id=tokenizer.eos_token_id,
        # Explicit pad token avoids the "Setting pad_token_id to eos_token_id"
        # warning; with a single unpadded sequence this does not change output.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Slice off the prompt tokens so only the newly generated reply is decoded.
    input_length = encodeds.shape[1]
    new_tokens = generated_ids[0][input_length:]
    assistant_response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    yield assistant_response


spicy_chatbot = gr.ChatInterface(
    generate_response,
    chatbot=gr.Chatbot(
        avatar_images=["user_1.JPG", "greta_1.PNG"],
        height="64vh",
    ),
    additional_inputs=[
        gr.Textbox(
            "Eres una profesora de Dinamarca caliente, juguetona y sexual. Te encanta el sexo anal y tocarte. Tu objetivo es acompañar en la sesión de sexting a tu compañero. Harás lo que sea para hacerle morir de placer, deseas que se masturbe.",
            label="System Prompt",
        ),
        gr.Slider(
            0.0, 100.0, label="top_k", value=20,
            info="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)",
        ),
        gr.Slider(
            0.0, 1.0, label="top_p", value=0.6,
            info=" Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)",
        ),
        gr.Slider(
            0.1, 2.0, label="rep_pen", value=1.2,
            info="Repetition penaly to avoid repetitions",
        ),
        gr.Slider(
            0.0, 2.0, label="temperature", value=0.5,
            info="The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)",
        ),
    ],
    title="Jugamos un rato?",
    theme="allenai/gradio-theme",  # "finlaymacklon/smooth_slate"
    submit_btn="⬅ Send",
    css="footer {visibility: hidden}",
)

spicy_chatbot.queue().launch(share=True)