Spaces:
Build error
Build error
| from transformers import Blip2Processor, Blip2ForConditionalGeneration | |
| from peft import LoraConfig, get_peft_model, PeftModel | |
| import torch | |
| import streamlit as st | |
| from PIL import Image | |
| from streamlit_chat import message | |
| from io import BytesIO, StringIO | |
# Pinned to CPU. To use a GPU when one is available, switch to:
#     device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"


@st.cache_resource
def load_model():
    """Load the BLIP-2 processor and the fake-news LoRA-adapted model.

    Loads the base ``Salesforce/blip2-flan-t5-xl`` checkpoint and applies the
    fine-tuned PEFT adapter stored in ``./blip2_fakenews_all`` on top of it.
    The result is cached with ``st.cache_resource``, so the weights are loaded
    once and reused instead of being reloaded on every Streamlit rerun.

    Returns:
        tuple: ``(processor, model)``, with the model moved to ``device`` and
        switched to eval mode.
    """
    adapter_path = "./blip2_fakenews_all"
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
    model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
    # Apply the trained adapter. Do NOT additionally call get_peft_model() on
    # the result: that would inject a second, freshly initialised LoRA on top
    # of the loaded one, which is useless (and confusing) for inference.
    model = PeftModel.from_pretrained(model, adapter_path)
    model.to(device)
    model.eval()
    return processor, model
st.title('Blip2 Fake News Debunker')

# Per-session chat state, created only if absent so it survives reruns:
#   generated  - model answers displayed in the chat
#   past       - user questions displayed in the chat
#   bot_prompt - prompt lines accumulated for the model
for state_key in ('generated', 'past', 'bot_prompt'):
    if state_key not in st.session_state:
        st.session_state[state_key] = []
def get_text():
    """Render the chat input box and return the text currently entered in it."""
    return st.text_input('Start to chat:',
                         placeholder="Hello! Let's start to chat from here! ")
def generate_output(image, prompt, max_length=20):
    """Ask the fine-tuned BLIP-2 model *prompt* about *image*.

    Uses the module-level ``processor``, ``model`` and ``device`` globals,
    which must be initialised before this is called.

    Args:
        image: PIL image of the news item.
        prompt: Text prompt (the conversation so far) fed to the model.
        max_length: ``max_length`` passed to ``model.generate``; the default
            keeps the previous hard-coded value of 20.

    Returns:
        str: The decoded generation(s), joined with spaces, with special
        tokens removed.
    """
    # Truncate long prompts to 512 tokens. A single example needs no padding;
    # padding to max_length without handing the attention mask to generate()
    # made the model attend to pad tokens.
    encoding = processor(images=image, text=prompt, max_length=512,
                         truncation=True, return_tensors="pt").to(device)
    predictions = model.generate(input_ids=encoding['input_ids'],
                                 attention_mask=encoding['attention_mask'],
                                 pixel_values=encoding['pixel_values'],
                                 max_length=max_length)
    decoded = processor.batch_decode(predictions, skip_special_tokens=True)
    return " ".join(decoded)
if st.button('Start a new chat'):
    # Reset the conversation only. st.cache_resource is deliberately NOT
    # cleared: it is meant for heavyweight shared resources (e.g. a cached
    # model), which a chat reset should not force to be reloaded.
    st.cache_data.clear()
    # Snapshot the keys first so we never delete while iterating a live view.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    # st.experimental_rerun was deprecated in favour of st.rerun (Streamlit
    # 1.27) and later removed; fall back to it only on older versions.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()
# Two side-by-side columns: image upload on the left, text and chat on the right.
col1, col2 = st.columns(2)
# Placeholder, created after the columns, used for the "please upload" hint.
show_file = st.empty()

with col1:
    st.markdown("Step 1: ")
    uploaded_file = st.file_uploader("Upload a news image here: ", type=["png", "jpg"])
    if isinstance(uploaded_file, BytesIO):
        # A file was uploaded: decode and preview it.
        image = Image.open(uploaded_file)
        st.image(image)
    elif not uploaded_file:
        show_file.info("Please upload a file of type: " + ", ".join(["png", "jpg"]))

with col2:
    st.markdown("Step 2: ")
    txt = st.text_area("Paste news content here: ")
    st.markdown("Step 3: ")
    user_input = get_text()

# generate_output() reads these two as module-level globals.
processor, model = load_model()
def main():
    """Answer a new chat question (if there is one) and render the chat.

    Relies on module-level values set earlier in the script: the widget
    values ``uploaded_file``, ``image``, ``txt`` and ``user_input``, the
    ``col2`` layout column, and ``generate_output``. Conversation state lives
    in ``st.session_state``:

    * ``bot_prompt`` -- prompt lines accumulated across turns and joined with
      newlines before every model call;
    * ``past`` / ``generated`` -- questions and answers shown in the chat.
    """
    past = st.session_state.past
    # Streamlit reruns the whole script on every widget interaction and
    # st.text_input keeps its value, so without this check the same question
    # would be re-asked (and duplicated in the chat) on unrelated reruns,
    # e.g. after editing the news text. Trade-off: repeating the previous
    # question verbatim is treated as already answered.
    is_new_question = bool(user_input) and (not past or past[-1] != user_input)

    if uploaded_file and is_new_question:
        if not st.session_state.bot_prompt:
            # First turn: seed the prompt with the pasted news text. The seed
            # already ends with the question, so it is not appended again.
            # NOTE(review): "Qustions" is kept verbatim because the fine-tuned
            # adapter may depend on this exact wording; confirm before fixing.
            prompt = "Qustions: What is this news about? " \
                     "\nAnswer: " + txt + \
                     "\nQustions: " + user_input
            # Drop empty lines.
            st.session_state.bot_prompt = [line for line in prompt.split('\n') if line]
            print(f'init: {st.session_state.bot_prompt}')
        else:
            st.session_state.bot_prompt.append(f'You: {user_input}')

        # Flatten the accumulated prompt lines into the model input.
        input_prompt: str = '\n'.join(st.session_state.bot_prompt)
        print(f'bot prompt input list:\n{st.session_state.bot_prompt}')
        print(f'bot prompt input string:\n{input_prompt}')
        output = generate_output(image, prompt=input_prompt)
        past.append(user_input)
        st.session_state.generated.append(output)
        # Record the answer so the next turn's prompt includes it.
        st.session_state.bot_prompt.append(f'Answer: {output}')

    with col2:
        generated = st.session_state['generated']
        # Newest turn first; within a turn the answer is drawn before the question.
        for i in reversed(range(len(generated))):
            message(generated[i], key=str(i))
            message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')


if __name__ == '__main__':
    main()