# image-text / app.py
# IsraelSalgado's picture
# Update app.py
# 2134a77 verified
import io
import os
import time
from threading import Thread

import requests
import streamlit as st
import torch  # was commented out, but torch.device() below requires it
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    QuantoConfig,
    TextIteratorStreamer,
)

# Prefer GPU when available; the model itself is loaded with
# device_map="cuda" in main(), and inputs must land on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def reduce_image_size(img, scale_percent=50):
    """Return *img* resized to ``scale_percent`` percent of its original size.

    Args:
        img: A PIL image (anything exposing ``.size`` and ``.resize``).
        scale_percent: Percentage of the original width/height to keep
            (default 50, i.e. half size in each dimension).

    Returns:
        The resized image. Each dimension is clamped to at least 1 pixel
        so very small inputs never yield an invalid zero-sized image.
    """
    width, height = img.size
    # max(..., 1) guards against int() truncating a tiny dimension to 0,
    # which would make Image.resize raise.
    new_width = max(1, int(width * scale_percent / 100))
    new_height = max(1, int(height * scale_percent / 100))
    return img.resize((new_width, new_height))
def model_inference(user_prompt, chat_history, max_new_tokens, images):
    """Stream a model response for *user_prompt* about *images*.

    Appends the user turn to *chat_history* (mutated in place), starts
    ``model.generate`` on a background thread, and yields the accumulated
    response text as tokens arrive so the caller can render it live.

    Relies on module-level ``model``, ``processor`` and ``device`` being
    initialized (see ``main``).

    Args:
        user_prompt: The user's question text.
        chat_history: List of chat-template messages; mutated in place.
        max_new_tokens: Generation budget for ``model.generate``.
        images: Image(s) forwarded to the processor.

    Yields:
        str: The response accumulated so far, with the model's
        ``<end_of_utterance>`` marker stripped when present.
    """
    chat_history.append(
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": user_prompt},
            ],
        }
    )

    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, timeout=5.0
    )
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "streamer": streamer,
        "do_sample": False,
    }
    prompt = processor.apply_chat_template(chat_history, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
    generation_args.update(inputs)

    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    # Derive the strip length from the marker itself rather than the
    # fragile hard-coded 18 used previously.
    end_marker = "<end_of_utterance>"
    acc_text = ""
    try:
        for text_token in streamer:
            time.sleep(0.04)  # slight pacing for smoother UI streaming
            acc_text += text_token
            if acc_text.endswith(end_marker):
                acc_text = acc_text[: -len(end_marker)]
            yield acc_text
    finally:
        # Reap the generation thread even if the consumer abandons the
        # generator early (previously join() was skipped in that case).
        thread.join()
def main():
    """Streamlit entry point: collect text + image and stream the model's answer."""
    st.title("Text and Image Input App")

    # Load the model and processor once, caching them in session state so
    # Streamlit reruns don't reload the large quantized model.
    global model, processor
    if "model" not in st.session_state:
        model_id = "HuggingFaceM4/idefics2-8b"
        quantization_config = QuantoConfig(weights="int8")
        st.session_state["processor"] = AutoProcessor.from_pretrained(model_id)
        st.session_state["model"] = AutoModelForImageTextToText.from_pretrained(
            model_id, device_map="cuda", quantization_config=quantization_config
        )
    model = st.session_state["model"]
    processor = st.session_state["processor"]

    # Get text input.
    text_input = st.text_input("Enter your text:")

    # Get image input — initialize so the Predict branch below never hits a
    # NameError when no image was provided (previously an actual crash).
    processed_image = None
    image_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if image_input is not None:
        image = Image.open(image_input)
        st.image(image, caption='Uploaded Image')
        processed_image = reduce_image_size(image)
    else:
        image_url = st.text_input("Enter image URL:")
        if image_url:
            try:
                response = requests.get(image_url, timeout=10)
                response.raise_for_status()
                img = Image.open(io.BytesIO(response.content))
            except (requests.RequestException, OSError) as exc:
                # Bad URL / non-image payload: report instead of crashing.
                st.error(f"Could not load image from URL: {exc}")
            else:
                st.image(img, caption='Image from URL')
                processed_image = reduce_image_size(img)

    if st.button("Predict"):
        if text_input and processed_image is not None:
            # model_inference is a generator; previously it was created but
            # never consumed, so nothing was generated or displayed. Stream
            # the partial responses into a single placeholder.
            placeholder = st.empty()
            for partial_text in model_inference(
                # Use the user's actual text (it was previously collected
                # but ignored in favor of a hard-coded prompt).
                user_prompt=text_input,
                chat_history=[],  # fresh history per prediction
                max_new_tokens=100,
                images=processed_image,
            ):
                placeholder.markdown(partial_text)
        else:
            st.warning("Please provide both text and an image before predicting.")