import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

base_model_name = "chaseharmon/Rap-Mistral-Big"

@st.cache_resource
def load_model():
    # Load the model in 4-bit NF4 quantization so it fits in GPU memory;
    # matmuls are computed in float16
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
        bnb_4bit_compute_dtype=torch.float16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        device_map="auto",
        quantization_config=nf4_config,
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    return model

@st.cache_resource
def load_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    # Mistral's tokenizer has no pad token by default; reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return tokenizer

def build_prompt(question):
    # Wrap the question in Mistral's [INST] ... [/INST] instruction format
    prompt = f"[INST] {question} [/INST] "
    return prompt

model = load_model()
model.eval()

tokenizer = load_tokenizer()

st.title("Rap Verse Generation V1 Demo")
st.header("Supported Artists")
st.write("Lupe Fiasco, Common, Jay-Z, Yasiin Bey, Ab-Soul, Rakim")

prompt_placeholder = st.empty()
display_placeholder = st.empty()

prompt_placeholder.write("Ask the AI to write a verse")
display_placeholder.write("")

question = st.chat_input("Write a verse in the style of Lupe Fiasco")
if question:
    display_placeholder.write("Loading...")
    prompt_placeholder.write(question)
    prompt = build_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt")
    # Send the inputs to wherever device_map placed the model
    model_inputs = inputs.to(model.device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=300, do_sample=True, pad_token_id=tokenizer.eos_token_id)
    # skip_special_tokens drops <s>/</s>; [/INST] is ordinary text and survives
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Keep only the text after the [/INST] tag; fall back to the full
    # output if the tag is missing
    inst_idx = response.find("[/INST]")
    if inst_idx >= 0:
        actual_response = response[inst_idx + len("[/INST]"):].strip()
    else:
        actual_response = response
    # Two trailing spaces before "\n" force a Markdown line break in st.write
    actual_response = actual_response.replace("\n", "  \n")
    # Soften slurs in the generated verse before display
    actual_response = actual_response.replace("nigga", "brotha")
    actual_response = actual_response.replace("Nigga", "Brotha")
    actual_response = actual_response.replace("faggot", "f--got")
    display_placeholder.write(actual_response)
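
Assuming the script is saved as something like `app.py` (the filename is just an example), the demo launches locally with `streamlit run app.py`. The first run downloads the quantized checkpoint from Hugging Face, and the 4-bit bitsandbytes config above requires a CUDA-capable GPU.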