import streamlit as st
from transformers import AutoModelForCausalLM, GPT2Tokenizer, StoppingCriteria, StoppingCriteriaList
from transformers import TextIteratorStreamer
from threading import Thread
import torch
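
# Load the fine-tuned GPT-2 model and tokenizer once at startup, moving the model
# to GPU when one is available (the Space itself runs on CPU).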
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

PROJECT_MODEL = "RickMartel/GPT2_FT_By_NT_RAND_v11"
model = AutoModelForCausalLM.from_pretrained(PROJECT_MODEL)
model = model.to(device)
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained(PROJECT_MODEL)
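
# Custom stopping criterion: stop generation once the stop token has appeared
# `encounters` times in the sequence and the most recently generated token is that
# stop token, so the output always ends on a sentence boundary.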
class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=(), encounters=1):
        super().__init__()
        self.stops = [stop.to(device) for stop in stops]
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_tkn = input_ids[0][-1]
        stop_word_found = False
        for stop in self.stops:
            if (input_ids[0] == stop).sum() >= self.encounters:
                stop_word_found = True
        return bool(stop_word_found and self.stops[0] == last_tkn)
# StoppingCriteriaSub assumes the period is the first token id in `stops`.
stop_words = ['.']
stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze()
                  for stop_word in stop_words]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=3)])
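
# Streamlit page: sidebar with model notes plus a simple prompt form.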
st.set_page_config(page_title="GPT2 4 Bible")
st.sidebar.title("GPT2 4 Bible")
st.sidebar.markdown(
    """
Model notes:
- This is a fine-tuned Hugging Face distilgpt2 model.
- The dataset used was the Christian New Testament.
- This Space runs on CPU only, so the app is slow.
- This is a document-completion model, not Q&A. Enter a prompt like "Jesus said".
"""
)
form = st.form(key='my-form')
txt = form.text_input('Enter a prompt')
submit = form.form_submit_button('Submit')
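
# On submit, tokenize the prompt and stream the generated continuation into the
# page as tokens arrive.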
if submit:
    with st.spinner('Processing...'):
        st.markdown("<h4 style='text-align: left;'>Response:</h4>", unsafe_allow_html=True)
        ta = st.empty()

        inputs = tokenizer([tokenizer.bos_token + txt], return_tensors="pt").to(device)
        streamer = TextIteratorStreamer(tokenizer)
        generation_kwargs = dict(inputs,
                                 streamer=streamer,
                                 stopping_criteria=stopping_criteria,
                                 do_sample=True,
                                 max_new_tokens=200)

        # generate() runs on a worker thread; the streamer yields decoded text here.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text.replace('"', "").replace(tokenizer.bos_token, "")
            ta.write(generated_text)
        thread.join()