File size: 3,007 Bytes
869b3d2
bd901c8
e4866f8
869b3d2
e4866f8
869b3d2
e4866f8
 
 
 
 
 
 
 
 
 
 
 
bd901c8
 
e4866f8
 
 
35d6f94
bd901c8
e4866f8
869b3d2
e4866f8
869b3d2
e4866f8
869b3d2
bd901c8
e4866f8
 
bd901c8
869b3d2
bd901c8
869b3d2
 
 
 
e4866f8
 
 
 
 
869b3d2
e4866f8
 
 
 
bd901c8
 
e4866f8
 
bd901c8
e4866f8
 
 
 
 
 
 
bd901c8
e4866f8
 
 
 
 
 
 
 
 
 
869b3d2
 
 
 
 
 
bd901c8
869b3d2
 
bd901c8
869b3d2
 
bd901c8
869b3d2
bd901c8
 
e4866f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import streamlit as st
from PIL import Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import os

# -----------------------
# Streamlit config
# -----------------------
st.set_page_config(page_title="Multimodal Image Understanding AI", layout="centered")
st.title("πŸ“Έ Multimodal Image Understanding & Storytelling AI")

# Intro blurb shown under the title; joined with newlines so the rendered
# markdown is identical to a single multi-line string.
_INTRO_LINES = (
    "Upload an image or use live camera, and get:",
    "- Caption",
    "- Summary",
    "- Detected objects",
    "- Emotion/mood",
    "- Short story inspired by the image",
)
st.markdown("\n".join(_INTRO_LINES))

# -----------------------
# Model settings
# -----------------------
MODEL_NAME = "Salesforce/blip2-flan-t5-xl"
# Prefer the GPU when one is visible; otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Optional Hugging Face access token (configure as a secret in Spaces).
HF_TOKEN = os.environ.get("HF_TOKEN")

@st.cache_resource(show_spinner="πŸ”„ Loading AI model, please wait...")
def load_model():
    """Load the BLIP-2 processor and model once; Streamlit caches the pair."""
    on_gpu = DEVICE == "cuda"
    processor = Blip2Processor.from_pretrained(MODEL_NAME, use_fast=False, token=HF_TOKEN)
    # fp16 + automatic device placement on GPU; plain fp32 on CPU.
    model_kwargs = {
        "torch_dtype": torch.float16 if on_gpu else torch.float32,
        "device_map": "auto" if on_gpu else None,
        "token": HF_TOKEN,
    }
    model = Blip2ForConditionalGeneration.from_pretrained(MODEL_NAME, **model_kwargs)
    model.eval()
    return processor, model

processor, model = load_model()

# -----------------------
# Image input
# -----------------------
image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
camera_image = st.camera_input("Or take a live picture")

# A fresh camera shot takes precedence over an uploaded file.
source = camera_image or image_file
image = Image.open(source).convert("RGB") if source else None

if image:
    # Fix: `use_column_width` is deprecated in recent Streamlit releases;
    # `use_container_width` is the documented replacement.
    st.image(image, caption="Your Image", use_container_width=True)

    # -----------------------
    # Helper function
    # -----------------------
    def ask_model(prompt):
        """Run one image+text prompt through BLIP-2 and return the decoded text.

        Uses the module-level `processor`, `model`, `image`, and `DEVICE`.
        """
        # Fix: match the model's dtype. On CUDA the weights are fp16, and
        # feeding fp32 pixel values raises a dtype-mismatch error; BatchFeature
        # casts only floating-point tensors, so input_ids stay integral.
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(
            DEVICE,
            torch.float16 if DEVICE == "cuda" else torch.float32,
        )
        # inference_mode: no autograd bookkeeping -> lower memory, same output.
        with torch.inference_mode():
            out = model.generate(**inputs, max_new_tokens=150)
        return processor.decode(out[0], skip_special_tokens=True)

    with st.spinner("🧠 Analyzing image..."):
        caption = ask_model("Describe this image in one factual sentence.")
        summary = ask_model("Give a concise 3–5 line descriptive summary of this image.")
        objects = ask_model("List the main objects and entities visible in this image.")
        emotion = ask_model("Detect the emotional tone or mood of this image (happy, calm, tense, etc.).")
        story = ask_model("Write a short story (5–10 lines) inspired by this image.")

    # -----------------------
    # Output
    # -----------------------
    st.subheader("πŸ“ Caption")
    st.write(caption)

    st.subheader("πŸ“„ Summary")
    st.write(summary)

    st.subheader("πŸ“¦ Detected Objects")
    st.write(objects)

    st.subheader("😊 Emotional Tone")
    st.write(emotion)

    st.subheader("πŸ“– Short Story")
    st.write(story)

else:
    st.info("⬆️ Upload an image or use the camera above to begin.")