Spaces:
Runtime error
Runtime error
| import json | |
| import math | |
| import random | |
| import os | |
| import streamlit as st | |
| import transformers | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Hugging Face Hub id of the fine-tuned lyrics model.
MODEL_ID = "BenBranyon/hotrss-mistral-full"

# Keep this the first Streamlit call in the script: set_page_config() must
# run before any other Streamlit command (the cached loader below may show a
# spinner, which counts as one).
st.set_page_config(page_title="House of the Red Solar Sky")


@st.cache_resource(show_spinner="Loading model...")
def load_model(model_id=MODEL_ID):
    """Load the tokenizer and causal language model for *model_id*, once.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps one loaded copy per server process instead of
    re-instantiating the full model on each rerun.

    Args:
        model_id: Hugging Face Hub repository id to load.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    lm = AutoModelForCausalLM.from_pretrained(model_id)
    return tok, lm


# Page styling, injected as raw HTML.
# NOTE(review): the ``st-emotion-cache-*`` class names are generated by
# Streamlit's styling layer and presumably change between Streamlit releases;
# re-check them after upgrading. ``#house-of-the-red-solar-sky`` is presumably
# the anchor id Streamlit derives from the st.title() text below.
st.markdown(
    """
<style>
#house-of-the-red-solar-sky {
text-align: center;
}
.stApp {
background-image: url('https://f4.bcbits.com/img/a1824579252_16.jpg');
background-repeat: no-repeat;
background-size: cover;
background-blend-mode: hard-light;
}
.st-emotion-cache-1avcm0n {
background: none;
}
.st-emotion-cache-183lzff {
overflow-x: unset;
text-wrap: pretty;
}
.aligncenter {
text-align: center;
}
</style>
""",
    unsafe_allow_html=True,
)

st.title("House of the Red Solar Sky")

# Module-level names used by post_process() and the generation UI below.
tokenizer, model = load_model()
def _limit_repeats(lines, max_repeat):
    """Cap every run of identical consecutive lines at *max_repeat* copies."""
    kept = []
    previous = None
    run_length = 0
    for line in lines:
        run_length = run_length + 1 if line == previous else 1
        previous = line
        if run_length <= max_repeat:
            kept.append(line)
    return kept


def post_process(output_sequences, *, max_repeat=None, tok=None):
    """Decode generated token sequences into cleaned-up lyric strings.

    Each sequence is decoded with special tokens skipped and tokenization
    spaces cleaned up. The text is then stripped and empty lines are removed.

    Args:
        output_sequences: Iterable of generated id sequences, each exposing
            ``.tolist()`` (e.g. the rows returned by ``model.generate``).
        max_repeat: If given, keep at most this many consecutive copies of an
            identical line. ``None`` (the default) keeps every line.
        tok: Tokenizer used for decoding. Defaults to the module-level
            ``tokenizer``.

    Returns:
        A list with one cleaned string per input sequence.

    Raises:
        ValueError: If ``max_repeat`` is given and is less than 1.
    """
    if max_repeat is not None and max_repeat < 1:
        raise ValueError(f"max_repeat must be >= 1, got {max_repeat!r}")
    tok = tokenizer if tok is None else tok

    predictions = []
    for sequence in output_sequences:
        text = tok.decode(
            sequence.tolist(),
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        ).strip()
        # Dropping empty lines collapses every run of newlines to one. A fixed
        # chain of str.replace() calls misses some run lengths (e.g. 5 or 7).
        lines = [line for line in text.split("\n") if line]
        if max_repeat is not None:
            lines = _limit_repeats(lines, max_repeat)
        predictions.append("\n".join(lines))
    return predictions
# Prompt box plus Run button: sample a continuation of the prompt and show it.
start = st.text_input("Beginning of the song:", "Rap like a Sasquatch in the trees")
if st.button("Run"):
    if not start.strip():
        # With add_special_tokens=False an empty prompt encodes to zero ids,
        # leaving model.generate() nothing to continue from.
        st.warning("Please enter the beginning of the song.")
    else:
        with st.spinner(text="Generating lyrics..."):
            # Encode the prompt without special tokens and move every tensor
            # in the encoding (input_ids and attention_mask) to the model's device.
            inputs = tokenizer(
                start, add_special_tokens=False, return_tensors="pt"
            ).to(model.device)
            output_sequences = model.generate(
                **inputs,  # pass attention_mask too, not just input_ids
                max_length=160,  # counts prompt tokens as well
                min_length=100,
                temperature=1.0,
                top_p=0.95,
                top_k=50,
                do_sample=True,
                repetition_penalty=1.0,  # 1.0 = no penalty
                num_return_sequences=1,
                # generate() falls back to EOS as the pad id (with a warning)
                # when none is set. Passing it explicitly keeps that behavior
                # and silences the warning.
                pad_token_id=tokenizer.eos_token_id,
            )
            predictions = post_process(output_sequences)
        st.subheader("Results")
        for prediction in predictions:
            st.text(prediction)