from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import streamlit as st
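
# Fine-tuned GPT-2 checkpoints from the Hugging Face Hub, one per politician;
# tokenizer and model are loaded eagerly at startup.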
MODELS = {
    'uribe': {
        'tokenizer': AutoTokenizer.from_pretrained("jhonparra18/uribe-twitter-assistant-30ep"),
        'model': AutoModelForCausalLM.from_pretrained("jhonparra18/uribe-twitter-assistant-30ep")},
    'petro': {
        'tokenizer': AutoTokenizer.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large"),
        'model': AutoModelForCausalLM.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large")}}
def callback_input_text(new_text):
    del st.session_state.input_user_txt
    st.session_state.input_user_txt = new_text
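
# Generate a continuation of input_text by sampling (top-k / nucleus) from the model.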
def text_completion(tokenizer, model, input_text: str, max_len: int = 100):
    tokenizer.padding_side = "left"  ## pad on the left so generation continues from the end of the prompt
    tokenizer.pad_token = tokenizer.eos_token
    input_ids = tokenizer([input_text], return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():  ## likely redundant since generate() does not compute gradients, kept just in case
        outputs = model.generate(**input_ids, do_sample=True, max_length=max_len, top_k=100, top_p=0.95)
    out_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return out_text
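
# Illustrative usage outside Streamlit (assumes the checkpoints above are available locally or downloadable):
#   text_completion(MODELS['uribe']['tokenizer'], MODELS['uribe']['model'],
#                   "Mi gobierno no es corrupto", max_len=60)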
| st.markdown("<h3 style='text-align: center; color: gray;'> 🐦 Tweet de Pol铆tico Colombiano: Autocompletado/generaci贸n de texto a partir de GPT2</h3>", unsafe_allow_html=True) | |
| st.text("") | |
| st.markdown("<h5 style='text-align: center; color: gray;'>Causal Language Modeling, source code <a href='https://github.com/statscol/twitter-user-autocomplete-assistant'> here </a> </h5>", unsafe_allow_html=True) | |
| st.text("") | |
col1, col2 = st.columns(2)

with col1:
    with st.form("input_values"):
        politician = st.selectbox(
            "Selecciona el político",
            ("Uribe", "Petro")
        )
        st.text("")
        max_length_text = st.slider('Num Max Tokens', 50, 200, 100, 10, key="user_max_length")
        st.text("")
        input_user_text = st.empty()
        input_text_value = input_user_text.text_area('Input Text', 'Mi gobierno no es corrupto', key="input_user_txt", height=300)
        st.text("")
        complete_input = st.checkbox("Complete Input [Experimental]", value=False, help="Automáticamente rellenar el texto inicial con el resultado para una nueva iteración")
        go_button = st.form_submit_button('Generate')
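
# Right column: show the generated tweet and, if "Complete Input" is checked,
# feed the output back into the left-hand text area for the next iteration.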
with col2:
    if go_button:  ## only generate on form submit, avoiding a rerun on every widget change
        with st.spinner('Generating Text...'):
            output_text = text_completion(MODELS[politician.lower()]['tokenizer'], MODELS[politician.lower()]['model'], input_text_value, max_length_text)
            st.text_area("Tweet:", output_text, height=500, key="output_text")
            if complete_input:
                callback_input_text(output_text)
                input_user_text.text_area("Input Text", output_text, height=300)