|
|
import nest_asyncio

# Patch asyncio to allow re-entrant event loops BEFORE the other imports --
# presumably so Streamlit / model libraries that spin up a loop on import
# don't conflict (kept first to preserve the original ordering).
nest_asyncio.apply()

import json
import logging

import streamlit as st
import together
import torch
from dotenv import load_dotenv
from PIL import Image
from transformers import (
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)

# Silence transformers' verbose warning stream (e.g. generation-config notices).
logging.getLogger("transformers").setLevel(logging.ERROR)

# Pull API credentials (e.g. the Together AI key) from a local .env file.
load_dotenv()
|
class ImprovedVisualChatbot:
    """Visual chatbot combining two captioning models with an LLM.

    BLIP and ViT-GPT2 each produce an independent description of the
    uploaded image; BLIP is additionally reused for question-conditioned
    captions ("visual QA"). The merged analysis is handed to Together AI's
    Mistral model to generate conversational answers.
    """

    def __init__(self):
        # Run every vision model on GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Primary captioner (also reused for question-conditioned captions).
        self.blip_processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        )
        self.blip_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to(self.device)

        # Secondary captioner, used for an independent second opinion.
        self.vit_gpt2_model = VisionEncoderDecoderModel.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        ).to(self.device)
        self.vit_gpt2_feature_extractor = ViTImageProcessor.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        self.vit_gpt2_tokenizer = AutoTokenizer.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )

        # Chat history lives in Streamlit session state so it survives reruns.
        if "messages" not in st.session_state:
            st.session_state.messages = []

    def get_blip_description(self, image: "Image.Image") -> str:
        """Return a detailed caption of *image* using the BLIP model."""
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)

        # Inference only: no_grad avoids building autograd state.
        # NOTE: the original also passed temperature=1.0, which is ignored
        # (and warned about) under pure beam search without do_sample.
        with torch.no_grad():
            outputs = self.blip_model.generate(
                **inputs,
                max_length=100,
                num_beams=5,
                repetition_penalty=1.2,
                length_penalty=1.0,
            )

        return self.blip_processor.decode(outputs[0], skip_special_tokens=True)

    def get_vit_gpt2_description(self, image: "Image.Image") -> str:
        """Return an alternative caption of *image* using ViT-GPT2."""
        pixel_values = self.vit_gpt2_feature_extractor(
            images=image, return_tensors="pt"
        ).pixel_values.to(self.device)

        with torch.no_grad():
            output_ids = self.vit_gpt2_model.generate(
                pixel_values,
                max_length=50,
                num_beams=4,
                temperature=0.8,
                do_sample=True,
            )

        return self.vit_gpt2_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    def get_visual_qa(self, image: "Image.Image", question: str) -> str:
        """Answer *question* about *image* via BLIP's text-conditioned captioning.

        NOTE(review): this is a captioning checkpoint, not a dedicated VQA
        model; the question acts as a text condition on the caption, so
        answers are best-effort -- confirm a VQA checkpoint isn't intended.
        """
        inputs = self.blip_processor(image, question, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.blip_model.generate(
                **inputs,
                max_length=50,
                num_beams=4,
                temperature=0.8,
                do_sample=True,
            )

        return self.blip_processor.decode(outputs[0], skip_special_tokens=True)

    def analyze_image(self, image: "Image.Image") -> dict:
        """Run the full multi-model analysis pipeline on *image*.

        Returns:
            dict with keys "blip_description", "vit_gpt2_description" and
            "detailed_analysis" (question -> answer mapping).
        """
        blip_desc = self.get_blip_description(image)
        vit_gpt2_desc = self.get_vit_gpt2_description(image)

        # A fixed battery of questions gives the LLM structured context.
        standard_questions = [
            "What is the main subject of this image?",
            "What is the setting or location?",
            "What is the lighting and time of day?",
            "Are there any people in the image?",
            "What activities are happening?",
            "What colors are prominent?"
        ]

        qa_results = {q: self.get_visual_qa(image, q) for q in standard_questions}

        return {
            "blip_description": blip_desc,
            "vit_gpt2_description": vit_gpt2_desc,
            "detailed_analysis": qa_results
        }

    @staticmethod
    def _extract_assistant_text(raw_text: str) -> str:
        """Pull the assistant's content out of a possibly-JSON completion.

        The model sometimes echoes the JSON message list back; if *raw_text*
        parses as JSON, dig the assistant message's content out of it,
        otherwise return the text unchanged.
        """
        if not (raw_text.startswith('{') or raw_text.startswith('[')):
            return raw_text

        try:
            json_obj = json.loads(raw_text)
        except json.JSONDecodeError:
            # Looked like JSON but wasn't -- hand back the raw text.
            return raw_text

        if isinstance(json_obj, list):
            for item in json_obj:
                if isinstance(item, dict) and (
                    item.get("role") == "assistant" or item.get("name") == "assistant"
                ):
                    return item.get("content", "Error: Content not found.")
        elif isinstance(json_obj, dict):
            if "content" in json_obj:
                return json_obj["content"]
            if json_obj.get("role") == "assistant" or json_obj.get("name") == "assistant":
                return json_obj.get("content", "Error: Content not found.")

        # Fall back to pretty-printed JSON so the user at least sees something.
        return json.dumps(json_obj, indent=2)

    def get_chat_response(self, prompt: str, analysis_results: dict) -> str:
        """Generate a reply to *prompt* grounded in *analysis_results*.

        Args:
            prompt: The user's question about the image.
            analysis_results: Output of analyze_image().

        Returns:
            The assistant's reply, or an error string for malformed responses.
        """
        system_prompt = f"""You are an advanced visual AI assistant analyzing an image.
Image Analysis Results:
1. Primary Description (BLIP): {analysis_results['blip_description']}
2. Secondary Description (ViT-GPT2): {analysis_results['vit_gpt2_description']}
3. Detailed Analysis:
{json.dumps(analysis_results['detailed_analysis'], indent=2)}

Guidelines:
1. Use all available descriptions to provide accurate information.
2. When descriptions differ, mention both perspectives.
3. If asked about details not covered in the analysis, acknowledge the limitation.
4. Maintain a natural, conversational tone while being precise.
5. If there's uncertainty, explain why and what can be confidently stated.

Please respond to the user's query based on this comprehensive analysis.
"""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]

        # NOTE(review): the chat transcript is sent as one JSON string through
        # the plain completion endpoint; Mistral-Instruct normally expects
        # [INST]-style formatting. Kept as-is to preserve behavior -- verify.
        response = together.Complete.create(
            prompt=json.dumps(messages),
            model="mistralai/Mistral-7B-Instruct-v0.2",
            max_tokens=1024,
            temperature=0.7,
            top_k=50,
            top_p=0.7,
            repetition_penalty=1.1
        )

        if isinstance(response, dict) and 'choices' in response:
            raw_text = response['choices'][0]['text'].strip()
            return self._extract_assistant_text(raw_text)

        return "Error: Unable to fetch a valid response."
|
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def _get_chatbot() -> "ImprovedVisualChatbot":
    """Build the chatbot once per server process; loading three HF models is slow."""
    return ImprovedVisualChatbot()


def main():
    """Streamlit entry point: sidebar image upload plus a chat interface."""
    st.set_page_config(page_title="Multimodal Visual AI Chatbot", layout="wide")
    st.title("🤖 Multimodal Visual AI Chatbot")

    # BUGFIX: the original constructed ImprovedVisualChatbot() directly here,
    # re-loading every model on each Streamlit rerun (every user interaction).
    chatbot = _get_chatbot()

    with st.sidebar:
        st.header("Upload Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

        if uploaded_file is not None:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_container_width=True)

            # BUGFIX: the original only checked that "analysis_results" existed,
            # so uploading a *different* image never triggered re-analysis.
            # Key the cached analysis on the uploaded file's identity instead.
            file_key = (uploaded_file.name, uploaded_file.size)
            if st.session_state.get("analyzed_file_key") != file_key:
                with st.spinner("Analyzing image (this may take a moment)..."):
                    st.session_state.analysis_results = chatbot.analyze_image(image)
                st.session_state.analyzed_file_key = file_key

            st.success("✅ You can now chat with the image!")

    st.header("Chat")

    # Replay the conversation so far (session_state persists across reruns).
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    if prompt := st.chat_input("Ask about the image..."):
        # Guard: chatting requires a completed image analysis.
        if "analysis_results" not in st.session_state:
            st.warning("Please upload an image first!")
            return

        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = chatbot.get_chat_response(
                    prompt,
                    st.session_state.analysis_results
                )

            # Defensive: normalize a list of fragments into one string.
            if isinstance(response, list):
                response = " ".join(response)

            st.write(response)
            st.session_state.messages.append({"role": "assistant", "content": response})
|
# Script entry point: launch the Streamlit UI only when executed directly.
if __name__ == "__main__":
    main()
|
|
|
|