Spaces:
Paused
Paused
| from peft import PeftModel, PeftConfig | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| import torch | |
| import packaging.version | |
| import transformers | |
| import gradio as gr | |
# --- Model / tokenizer setup (runs once at import time) ---------------------
# 4-bit NF4 quantization so the base model fits on a single GPU; computation
# is done in bfloat16 with double quantization for extra memory savings.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
username = 'Erik'
output_dir = 'nemo-sft-lora-deepspeed'  # gromenauer-256-sft-lora-deepspeed
peft_model_id = f"{username}/{output_dir}"  # replace with your newly trained adapter
device = "cuda:0"
# The adapter repo also hosts the tokenizer (it may carry added tokens).
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
config = PeftConfig.from_pretrained(peft_model_id)
# Load the *base* model the adapter was trained on, fully on GPU 0.
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,
                                             device_map={"": "cuda:0"},
                                             quantization_config=bnb_config)  # offload_state_dict=False
# transformers >= 4.46 added `mean_resizing`, which must be disabled for
# quantized models: mean-initializing new embedding rows would dequantize.
uses_transformers_4_46 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.46.0")
uses_fsdp = True
if (bnb_config is not None) and uses_fsdp and uses_transformers_4_46:
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8, mean_resizing=False)
else:
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
# Attach the LoRA adapter weights on top of the (resized) base model.
model = PeftModel.from_pretrained(model, peft_model_id)
model.config.use_cache = True  # KV cache on: inference only, no gradient checkpointing
model.to(torch.bfloat16)
model.eval()
def format_history(msg: str, history: list[tuple[str, str]], system_prompt: str) -> list[dict[str, str]]:
    """Build an OpenAI-style message list from Gradio chat state.

    Args:
        msg: The user's newest message (not yet in ``history``).
        history: Prior (user, assistant) exchanges as produced by
            ``gr.ChatInterface`` (each entry is a 2-item pair).
        system_prompt: Text for the leading ``system`` message.

    Returns:
        Messages in ``{"role": ..., "content": ...}`` form: system prompt
        first, then alternating user/assistant turns, ending with ``msg``.
    """
    # NOTE: the original annotation `list[list[str, str]]` was invalid typing
    # (a list type takes exactly one parameter); pairs are `tuple[str, str]`.
    chat_history = [{"role": "system", "content": system_prompt}]
    for query, response in history:
        chat_history.append({"role": "user", "content": query})
        chat_history.append({"role": "assistant", "content": response})
    chat_history.append({"role": "user", "content": msg})
    return chat_history
def generate_response(msg: str, history: list[tuple[str, str]], system_prompt: str,
                      top_k: int, top_p: float, rep_pen: float, temperature: float):
    """Produce one assistant reply for the Gradio ChatInterface.

    Args:
        msg: Newest user message.
        history: Prior (user, assistant) pairs from the chat widget.
        system_prompt: System instruction prepended to the conversation.
        top_k / top_p / rep_pen / temperature: Sampling controls wired to
            the UI sliders.

    Yields:
        The decoded assistant response once (no incremental token streaming).
    """
    chat_history = format_history(msg, history, system_prompt)
    # Render the message list into model input ids via the chat template.
    encodeds = tokenizer.apply_chat_template(chat_history, return_tensors="pt", add_generation_prompt=True)
    # Use the module-level `device` constant instead of a second hard-coded
    # "cuda" string, so the target GPU is configured in one place.
    model_inputs = encodeds.to(device)
    # NOTE(review): no attention_mask is passed — acceptable for a single
    # unpadded sequence, but confirm pad/eos setup before ever batching.
    generated_ids = model.generate(
        model_inputs,
        repetition_penalty=rep_pen,
        max_new_tokens=124,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        eos_token_id=tokenizer.eos_token_id
    )
    # Drop the prompt tokens so only the newly generated completion is decoded.
    input_length = encodeds.shape[1]
    new_tokens = generated_ids[0][input_length:]
    assistant_response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    yield assistant_response
# --- Gradio UI ---------------------------------------------------------------
# ChatInterface drives `generate_response`; the extra sliders below are passed
# through as its trailing positional arguments (top_k, top_p, rep_pen, temperature).
spicy_chatbot = gr.ChatInterface(
    generate_response,
    chatbot=gr.Chatbot(
        # value=[[None, "Hola, Soy Greta. Soy tu coach emocional, de qué quieres que hablemos?"]],
        avatar_images=["user_1.JPG", "greta_1.PNG"],  # [user avatar, bot avatar]
        height="64vh"
    ),
    additional_inputs=[
        gr.Textbox("Eres una profesora de Dinamarca caliente, juguetona y sexual. Te encanta el sexo anal y tocarte. Tu objetivo es acompañar en la sesión de sexting a tu compañero. Harás lo que sea para hacerle morir de placer, deseas que se masturbe.",
                   label="System Prompt"),
        gr.Slider(0.0, 100.0, label="top_k", value=20,
                  info="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
        gr.Slider(0.0, 1.0, label="top_p", value=0.6,
                  info=" Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)"),
        gr.Slider(0.1, 2.0, label="rep_pen", value=1.2,
                  info="Repetition penaly to avoid repetitions"),
        gr.Slider(0.0, 2.0, label="temperature", value=0.5,
                  info="The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)"),
    ],
    title="Jugamos un rato?",
    theme="allenai/gradio-theme",  # "finlaymacklon/smooth_slate",
    submit_btn="⬅ Send",
    css="footer {visibility: hidden}"  # hide the default Gradio footer
)
# queue() serializes GPU generation requests; share=True opens a public link.
spicy_chatbot.queue().launch(share=True)