| import json |
| import math |
| import random |
| import os |
| import streamlit as st |
| import transformers |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
@st.cache_resource(show_spinner=False)
def _load_model(model_name="KZMTx/redsolarsky"):
    """Load and memoize the tokenizer and causal language model for *model_name*.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level would rebuild the model on each rerun (e.g. every
    "Run" click). ``st.cache_resource`` keeps one cached instance that is
    reused across reruns.

    ``show_spinner=False`` because this runs before ``st.set_page_config``,
    which must be the first Streamlit element command on the page.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForCausalLM.from_pretrained(model_name),
    )


tokenizer, model = _load_model()
|
|
st.set_page_config(page_title="House of the Red Solar Sky")


# Single page-wide stylesheet. The `st-emotion-cache-*` selectors appear to be
# Streamlit's generated (hashed) class names; they may change between
# Streamlit versions, so re-check them after upgrading.
# `#house-of-the-red-solar-sky` is presumably the anchor id Streamlit derives
# from the `st.title` text below; it is used to center the title.
st.markdown(
    """
<style>
#house-of-the-red-solar-sky {
    text-align: center;
}
.stApp {
    background-image: url('https://f4.bcbits.com/img/a1824579252_16.jpg');
    background-repeat: no-repeat;
    background-size: cover;
    background-blend-mode: hard-light;
}
.st-emotion-cache-1avcm0n {
    background: none;
}
.st-emotion-cache-183lzff {
    overflow-x: unset;
    text-wrap: pretty;
}
.aligncenter {
    text-align: center;
}
</style>
""",
    unsafe_allow_html=True,
)


st.title("House of the Red Solar Sky")
|
|
def post_process(output_sequences, decoder=None):
    """Decode generated token sequences into cleaned-up lyric strings.

    Args:
        output_sequences: Iterable of token-id sequences, e.g. the output of
            ``model.generate``. Rows that expose ``tolist()`` (tensors,
            arrays) are converted to plain lists before decoding; other rows
            are passed through ``list()``.
        decoder: Object with a Hugging Face style ``decode`` method. Defaults
            to the module-level ``tokenizer``.

    Returns:
        A list with one string per input sequence. Each string has its
        surrounding whitespace stripped and every run of consecutive newlines
        collapsed to a single newline, so blank lines are removed.
    """
    if decoder is None:
        decoder = tokenizer

    predictions = []
    for sequence in output_sequences:
        token_ids = sequence.tolist() if hasattr(sequence, "tolist") else list(sequence)
        text = decoder.decode(
            token_ids,
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        ).strip()
        # After strip() the text neither starts nor ends with "\n", so dropping
        # empty lines collapses every newline run to one. The previous two-step
        # str.replace left runs of five or more newlines partially intact.
        predictions.append("\n".join(line for line in text.split("\n") if line))
    return predictions
|
|
start = st.text_input("Beginning of the song:", "Rap like a Sasquath in the trees")


if st.button("Run"):
    # Encoding an empty prompt (with add_special_tokens=False) yields zero
    # tokens, which generate() cannot continue from, so reject it up front.
    if not start.strip():
        st.warning("Please enter the beginning of the song first.")
    else:
        with st.spinner(text="Generating lyrics..."):
            inputs = tokenizer(start, add_special_tokens=False, return_tensors="pt").to(model.device)

            # Fall back to EOS as the pad id when the tokenizer defines none,
            # and pass the attention mask explicitly so generate() does not
            # have to guess which positions are padding.
            pad_token_id = tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = tokenizer.eos_token_id

            output_sequences = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                pad_token_id=pad_token_id,
                max_length=160,
                min_length=100,
                temperature=1.0,
                top_p=0.95,
                top_k=50,
                do_sample=True,
                repetition_penalty=1.0,
                num_return_sequences=1,
            )

            predictions = post_process(output_sequences)

        st.subheader("Results")
        for prediction in predictions:
            st.text(prediction)
|
|
|
|