prachi1507 commited on
Commit
c7a78a2
·
verified ·
1 Parent(s): 901d8fe

upload files

Browse files
Files changed (2) hide show
  1. app.py +249 -0
  2. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nest_asyncio
2
+ nest_asyncio.apply()
3
+
4
+ import streamlit as st
5
+ from transformers import (
6
+ VisionEncoderDecoderModel,
7
+ ViTImageProcessor,
8
+ AutoTokenizer,
9
+ BlipProcessor,
10
+ BlipForConditionalGeneration
11
+ )
12
+ import together
13
+ import torch
14
+ from PIL import Image
15
+ from dotenv import load_dotenv
16
+ import json
17
+ import logging
18
+ logging.getLogger("transformers").setLevel(logging.ERROR)
19
+
20
+ # Load environment variables
21
+ load_dotenv()
22
+
23
+ class ImprovedVisualChatbot:
24
+ def __init__(self):
25
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
26
+
27
+ # Initialize BLIP model for detailed image understanding
28
+ self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
29
+ self.blip_model = BlipForConditionalGeneration.from_pretrained(
30
+ "Salesforce/blip-image-captioning-large"
31
+ ).to(self.device)
32
+
33
+ # Initialize ViT-GPT2 for additional image captioning
34
+ self.vit_gpt2_model = VisionEncoderDecoderModel.from_pretrained(
35
+ "nlpconnect/vit-gpt2-image-captioning"
36
+ ).to(self.device)
37
+ self.vit_gpt2_feature_extractor = ViTImageProcessor.from_pretrained(
38
+ "nlpconnect/vit-gpt2-image-captioning"
39
+ )
40
+ self.vit_gpt2_tokenizer = AutoTokenizer.from_pretrained(
41
+ "nlpconnect/vit-gpt2-image-captioning"
42
+ )
43
+
44
+ # Initialize session state
45
+ if "messages" not in st.session_state:
46
+ st.session_state.messages = []
47
+
48
+ def get_blip_description(self, image: Image) -> str:
49
+ """Get detailed image description using BLIP model"""
50
+ inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
51
+
52
+ # Generate detailed caption
53
+ outputs = self.blip_model.generate(
54
+ **inputs,
55
+ max_length=100,
56
+ num_beams=5,
57
+ temperature=1.0,
58
+ repetition_penalty=1.2,
59
+ length_penalty=1.0
60
+ )
61
+
62
+ return self.blip_processor.decode(outputs[0], skip_special_tokens=True)
63
+
64
+ def get_vit_gpt2_description(self, image: Image) -> str:
65
+ """Get additional perspective using ViT-GPT2 model"""
66
+ pixel_values = self.vit_gpt2_feature_extractor(
67
+ images=image, return_tensors="pt"
68
+ ).pixel_values.to(self.device)
69
+
70
+ output_ids = self.vit_gpt2_model.generate(
71
+ pixel_values,
72
+ max_length=50,
73
+ num_beams=4,
74
+ temperature=0.8,
75
+ do_sample=True
76
+ )
77
+
78
+ return self.vit_gpt2_tokenizer.decode(output_ids[0], skip_special_tokens=True)
79
+
80
+ def get_visual_qa(self, image: Image, question: str) -> str:
81
+ """Get answer for specific question about the image using BLIP"""
82
+ inputs = self.blip_processor(image, question, return_tensors="pt").to(self.device)
83
+
84
+ outputs = self.blip_model.generate(
85
+ **inputs,
86
+ max_length=50,
87
+ num_beams=4,
88
+ temperature=0.8,
89
+ do_sample=True
90
+ )
91
+
92
+ return self.blip_processor.decode(outputs[0], skip_special_tokens=True)
93
+
94
+ def analyze_image(self, image: Image) -> dict:
95
+ """Comprehensive image analysis using multiple models"""
96
+ # Get descriptions from both models
97
+ blip_desc = self.get_blip_description(image)
98
+ vit_gpt2_desc = self.get_vit_gpt2_description(image)
99
+
100
+ # Get answers to predetermined questions for better understanding
101
+ standard_questions = [
102
+ "What is the main subject of this image?",
103
+ "What is the setting or location?",
104
+ "What is the lighting and time of day?",
105
+ "Are there any people in the image?",
106
+ "What activities are happening?",
107
+ "What colors are prominent?"
108
+ ]
109
+
110
+ qa_results = {}
111
+ for question in standard_questions:
112
+ qa_results[question] = self.get_visual_qa(image, question)
113
+
114
+ return {
115
+ "blip_description": blip_desc,
116
+ "vit_gpt2_description": vit_gpt2_desc,
117
+ "detailed_analysis": qa_results
118
+ }
119
+
120
+ def get_chat_response(self, prompt: str, analysis_results: dict) -> str:
121
+ """Generate response using Together AI's Mistral model"""
122
+ system_prompt = f"""You are an advanced visual AI assistant analyzing an image.
123
+ Image Analysis Results:
124
+ 1. Primary Description (BLIP): {analysis_results['blip_description']}
125
+ 2. Secondary Description (ViT-GPT2): {analysis_results['vit_gpt2_description']}
126
+ 3. Detailed Analysis:
127
+ {json.dumps(analysis_results['detailed_analysis'], indent=2)}
128
+
129
+ Guidelines:
130
+ 1. Use all available descriptions to provide accurate information.
131
+ 2. When descriptions differ, mention both perspectives.
132
+ 3. If asked about details not covered in the analysis, acknowledge the limitation.
133
+ 4. Maintain a natural, conversational tone while being precise.
134
+ 5. If there's uncertainty, explain why and what can be confidently stated.
135
+
136
+ Please respond to the user's query based on this comprehensive analysis.
137
+ """
138
+
139
+ messages = [
140
+ {"role": "system", "content": system_prompt},
141
+ {"role": "user", "content": prompt}
142
+ ]
143
+
144
+ response = together.Complete.create(
145
+ prompt=json.dumps(messages),
146
+ model="mistralai/Mistral-7B-Instruct-v0.2",
147
+ max_tokens=1024,
148
+ temperature=0.7,
149
+ top_k=50,
150
+ top_p=0.7,
151
+ repetition_penalty=1.1
152
+ )
153
+
154
+ # Ensure clean text output
155
+ if isinstance(response, dict) and 'choices' in response:
156
+ raw_text = response['choices'][0]['text'].strip()
157
+
158
+ # If the raw text appears to be JSON (starts with { or [)
159
+ if raw_text.startswith('{') or raw_text.startswith('['):
160
+ try:
161
+ # First, attempt to parse as JSON
162
+ json_obj = json.loads(raw_text)
163
+
164
+ # Case 1: If it's a list of messages like [{"name": "assistant", ...}]
165
+ if isinstance(json_obj, list):
166
+ for item in json_obj:
167
+ if isinstance(item, dict) and (item.get("role") == "assistant" or item.get("name") == "assistant"):
168
+ return item.get("content", "Error: Content not found.")
169
+
170
+ # Case 2: If it's a single message object like {"role": "assistant", ...}
171
+ elif isinstance(json_obj, dict):
172
+ if "content" in json_obj:
173
+ return json_obj["content"]
174
+ elif json_obj.get("role") == "assistant" or json_obj.get("name") == "assistant":
175
+ return json_obj.get("content", "Error: Content not found.")
176
+
177
+ # If we couldn't extract content but it parsed as JSON, return the stringified pretty version
178
+ return json.dumps(json_obj, indent=2)
179
+
180
+ except json.JSONDecodeError:
181
+ # Not valid JSON, return the raw text
182
+ return raw_text
183
+ else:
184
+ # Not JSON format, just return the raw text
185
+ return raw_text
186
+
187
+ return "Error: Unable to fetch a valid response."
188
+
189
+ def main():
190
+ st.set_page_config(page_title="Multimodal Visual AI Chatbot", layout="wide")
191
+ st.title("🤖 Multimodal Visual AI Chatbot")
192
+
193
+ # Initialize chatbot
194
+ chatbot = ImprovedVisualChatbot()
195
+
196
+ # Create sidebar for image upload and analysis details
197
+ with st.sidebar:
198
+ st.header("Upload Image")
199
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
200
+
201
+ if uploaded_file is not None:
202
+ image = Image.open(uploaded_file)
203
+ st.image(image, caption="Uploaded Image", use_container_width=True)
204
+
205
+ # Analyze image
206
+ if "analysis_results" not in st.session_state:
207
+ with st.spinner("Analyzing image (this may take a moment)..."):
208
+ analysis_results = chatbot.analyze_image(image)
209
+ st.session_state.analysis_results = analysis_results
210
+
211
+ # Display a message after successful analysis
212
+ st.success("✅ You can now chat with the image!")
213
+
214
+ # Main chat interface
215
+ st.header("Chat")
216
+
217
+ # Display chat messages
218
+ for message in st.session_state.messages:
219
+ with st.chat_message(message["role"]):
220
+ st.write(message["content"])
221
+
222
+ # Chat input
223
+ if prompt := st.chat_input("Ask about the image..."):
224
+ if "analysis_results" not in st.session_state:
225
+ st.warning("Please upload an image first!")
226
+ return
227
+
228
+ # Add user message to chat
229
+ st.session_state.messages.append({"role": "user", "content": prompt})
230
+ with st.chat_message("user"):
231
+ st.write(prompt)
232
+
233
+ # Get chatbot response
234
+ with st.chat_message("assistant"):
235
+ with st.spinner("Thinking..."):
236
+ response = chatbot.get_chat_response(
237
+ prompt,
238
+ st.session_state.analysis_results
239
+ )
240
+
241
+ # Ensure the response is a string (handle list output issue)
242
+ if isinstance(response, list):
243
+ response = " ".join(response)
244
+
245
+ st.write(response)
246
+ st.session_state.messages.append({"role": "assistant", "content": response})
247
+
248
+ if __name__ == "__main__":
249
+ main()
requirements.txt ADDED
Binary file (3.03 kB). View file