| import torch | |
| from torch.utils.data import Dataset | |
| from transformers import pipeline | |
| import streamlit as st | |
| import requests | |
def get_story(image_path):
    """Generate a short story for an image.

    Pipeline: caption the image with a ViT-GPT2 captioning model, show the
    caption in the Streamlit UI, then expand the caption into a ~100-word
    story using whichever model the user picked in the selectbox:

    - 'flan-t5-base': runs google/flan-t5-base locally (8-bit, auto device map).
    - 'alpaca-lora': calls the public tloen/alpaca-lora HF Space over HTTP.

    Parameters
    ----------
    image_path : str or PIL.Image
        Input accepted by the HF image-to-text pipeline (path/URL/image).

    Returns
    -------
    str
        The generated story text.

    Raises
    ------
    requests.HTTPError
        If the alpaca-lora Space responds with an error status.
    requests.Timeout
        If the Space does not answer within the timeout.
    """
    model_name = st.selectbox('Select the Model', ['alpaca-lora', 'flan-t5-base'])

    # Caption the image; the pipeline returns a list of dicts like
    # [{'generated_text': ...}].
    image_to_text = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    caption = image_to_text(image_path)
    caption = caption[0]['generated_text']
    st.write(f"Generated Caption: {caption}")

    input_string = f"""Question: Generate 100 words story on this text
'{caption}' Answer:"""

    if model_name == 'flan-t5-base':
        from transformers import T5ForConditionalGeneration, AutoTokenizer
        model = T5ForConditionalGeneration.from_pretrained(
            "google/flan-t5-base", device_map="auto", load_in_8bit=True
        )
        tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
        # BUG FIX: inputs were previously forced onto "cpu", but with
        # device_map="auto" the 8-bit model is usually placed on the GPU,
        # causing a device mismatch in generate(). Use the model's device.
        inputs = tokenizer(input_string, return_tensors="pt").input_ids.to(model.device)
        outputs = model.generate(inputs, max_length=1000)
        # skip_special_tokens: don't return the story wrapped in <pad>...</s>.
        outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    else:
        # Remote inference via the public alpaca-lora HF Space. The positional
        # "data" payload is the Space's Gradio API contract:
        # [instruction, input, temperature, top_p, top_k, num_beams, max_tokens].
        response = requests.post(
            "https://tloen-alpaca-lora.hf.space/run/predict",
            json={
                "data": [
                    "Write a story about this image caption",
                    caption,
                    0.1,
                    0.75,
                    40,
                    4,
                    128,
                ]
            },
            # BUG FIX: requests has no default timeout — without one the
            # Streamlit app hangs forever if the Space is down.
            timeout=120,
        )
        # Fail loudly on HTTP errors instead of KeyError-ing on "data" below.
        response.raise_for_status()
        data = response.json()["data"]
        outputs = data[0]
    return outputs