# NOTE: scraped page header from the original capture (file size 3,498 bytes, ref 2134a77) — not code.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# NOTE: removed extraction residue here — a hash (8efbdfa) and the line-number
# gutter (1..107) copied from the original page view; it was not part of the code.
# Standard library
import io
import os
import time
from threading import Thread

# Third-party
import requests
import streamlit as st
import torch  # was commented out, but torch.device(...) below requires it
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    QuantoConfig,
    TextIteratorStreamer,
)


# Pick GPU when available, else CPU; used to move processor inputs onto the
# same device as the model.
# NOTE(review): `import torch` is commented out in the import block above —
# as written this line raises NameError; the import must be restored.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def reduce_image_size(img, scale_percent=50):
    """Shrink *img* to ``scale_percent`` percent of its original dimensions.

    Args:
        img: A PIL-style image exposing ``size`` (width, height) and
            ``resize((w, h))``.
        scale_percent: Target size as a percentage of the original
            dimensions (default 50).

    Returns:
        The resized image returned by ``img.resize``.
    """
    width, height = img.size
    # Guard with max(1, ...): integer truncation could otherwise produce a
    # zero-pixel dimension for tiny images or very small scale_percent
    # values, which PIL's resize rejects.
    new_width = max(1, int(width * scale_percent / 100))
    new_height = max(1, int(height * scale_percent / 100))
    return img.resize((new_width, new_height))


def model_inference(
    user_prompt, chat_history, max_new_tokens, images
):
    """Run the model on a text+image prompt and stream the reply.

    Generates the response on a background thread and yields the
    accumulated text after each new token so a UI can render it live.

    Args:
        user_prompt: The user's text question (plain string).
        chat_history: Mutable list of chat-template messages; the new user
            turn is appended to it in place.
        max_new_tokens: Generation cap passed to ``model.generate``.
        images: Image(s) accepted by the processor for this turn.

    Yields:
        str: The full response text accumulated so far, with the model's
        end-of-utterance marker stripped.

    NOTE(review): relies on module-level ``model`` and ``processor``
    globals being set (see ``main``) — confirm they are initialized
    before this is called.
    """
    end_marker = "<end_of_utterance>"

    chat_history.append({
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": user_prompt},
        ],
    })

    # skip_prompt=True: stream only newly generated tokens, not the prompt.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, timeout=5.0
    )

    generation_args = {
        "max_new_tokens": max_new_tokens,
        "streamer": streamer,
        "do_sample": False,  # deterministic greedy decoding
    }

    prompt = processor.apply_chat_template(chat_history, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
    generation_args.update(inputs)

    # generate() blocks, so run it on a worker thread and consume the
    # streamer on this one.
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    acc_text = ""
    for text_token in streamer:
        time.sleep(0.04)  # slight pacing so the UI update isn't too bursty
        acc_text += text_token
        if acc_text.endswith(end_marker):
            acc_text = acc_text[: -len(end_marker)]
        # BUG FIX: yield on every token. The original yielded only inside
        # the endswith() branch, so nothing streamed until (and unless)
        # the end-of-utterance marker appeared.
        yield acc_text

    thread.join()

def main():
    """Main function of the Streamlit app.

    Builds the UI: loads the idefics2-8b model/processor once (cached in
    ``st.session_state``), collects a text prompt plus an image (upload or
    URL), and on "Predict" calls ``model_inference``.
    """
    st.title("Text and Image Input App")

    # Load the model and processor outside the loop (once).
    # ``global`` makes them visible to model_inference, which reads these
    # module-level names directly.
    global model, processor
    if "model" not in st.session_state:
        model_id = "HuggingFaceM4/idefics2-8b"
        # int8 weight quantization to reduce GPU memory footprint.
        quantization_config = QuantoConfig(weights="int8")
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForImageTextToText.from_pretrained(
            model_id, device_map="cuda", quantization_config=quantization_config
        )
        st.session_state["model"] = model
        st.session_state["processor"] = processor

    # Rebind from session_state so reruns reuse the cached objects.
    model = st.session_state["model"]
    processor = st.session_state["processor"]

    # Get text input
    # NOTE(review): text_input is only used as a truthiness gate below; the
    # prompt actually sent is the hard-coded string — confirm intent.
    text_input = st.text_input("Enter your text:")

    # Get image input: either a file upload or, failing that, a URL.
    image_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if image_input is not None:
        image = Image.open(image_input)
        st.image(image, caption='Uploaded Image')
        processed_image = reduce_image_size(image)
    else:
        image_url = st.text_input("Enter image URL:")
        if image_url:
            # NOTE(review): no timeout or status check on this request —
            # a bad URL will raise inside Image.open.
            response = requests.get(image_url)
            img = Image.open(io.BytesIO(response.content))
            st.image(img, caption='Image from URL')
            processed_image = reduce_image_size(img)

    if st.button("Predict"):
        # NOTE(review): if neither an upload nor a URL was provided,
        # processed_image is never assigned and this line raises
        # NameError — initialize it to None before the branches.
        if text_input and processed_image:
            # model_inference returns a generator; it is not consumed in
            # the lines visible here — presumably displayed further down
            # (e.g. st.write_stream) — verify against the rest of the file.
            prediction = model_inference(
                user_prompt="And what is in this image?",
                chat_history=[],  # Initialize chat history here
                max_new_tokens=100,
                images=processed_image)