ganeshkumar383 committed on
Commit
928bc12
·
verified ·
1 Parent(s): 631dbaf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +413 -0
app.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VISUAL CONVERSATIONAL INTELLIGENCE ENGINE
3
+ ==========================================
4
+ A pluggable, image-grounded multi-turn conversational system.
5
+
6
+ Architecture:
7
+ - Session-based image memory (stored once, queried multiple times)
8
+ - Vision-Language Model (BLIP) for image-question answering
9
+ - REST-style core logic (pure functions)
10
+ - Gradio UI for demonstration
11
+
12
+ Academic Purpose:
13
+ Demonstrates AI system design for visual question answering with
14
+ conversational context, suitable for research evaluation.
15
+ """
16
+
17
+ import gradio as gr
18
+ from PIL import Image
19
+ from transformers import BlipProcessor, BlipForQuestionAnswering
20
+ import torch
21
+ from typing import Optional, Tuple, List
22
+ import uuid
23
+
24
+
25
+ # ============================================================================
26
+ # SESSION MEMORY MANAGEMENT
27
+ # ============================================================================
28
+
29
class SessionMemory:
    """
    In-memory store for image-grounded conversation sessions.

    Sessions are kept in ``self.sessions``, a dict keyed by a UUID4 string.
    Each session record holds:
    - 'uploaded_image': the PIL image attached to the session, or None
    - 'conversation_history': list of (question, answer) tuples

    Calls that reference an unknown session ID never raise: writes are
    silent no-ops and reads return an empty value (None or []).
    """

    def __init__(self):
        self.sessions = {}

    @staticmethod
    def _blank_state() -> dict:
        """Return a fresh, empty session record (new list on every call)."""
        return {
            'uploaded_image': None,
            'conversation_history': [],
        }

    def create_session(self) -> str:
        """Register a new empty session and return its ID."""
        session_id = str(uuid.uuid4())
        self.sessions[session_id] = self._blank_state()
        return session_id

    def store_image(self, session_id: str, image: Image.Image) -> None:
        """Attach an image to the session; ignored if the session is unknown."""
        session = self.sessions.get(session_id)
        if session is not None:
            session['uploaded_image'] = image

    def get_image(self, session_id: str) -> Optional[Image.Image]:
        """Return the session's image, or None if unset or session unknown."""
        session = self.sessions.get(session_id)
        if session is None:
            return None
        return session['uploaded_image']

    def add_to_history(self, session_id: str, question: str, answer: str) -> None:
        """Append a (question, answer) pair; ignored if the session is unknown."""
        session = self.sessions.get(session_id)
        if session is not None:
            session['conversation_history'].append((question, answer))

    def get_history(self, session_id: str) -> List[Tuple[str, str]]:
        """
        Return the session's conversation history.

        A shallow copy is returned so that callers (e.g. UI handlers that
        append transient messages) cannot mutate the stored history.
        Unknown sessions yield an empty list.
        """
        session = self.sessions.get(session_id)
        if session is None:
            return []
        return list(session['conversation_history'])

    def reset_session(self, session_id: str) -> None:
        """Clear the session's image and history; the session ID stays valid."""
        if session_id in self.sessions:
            self.sessions[session_id] = self._blank_state()
80
+
81
+
82
+ # ============================================================================
83
+ # VISION-LANGUAGE MODEL INITIALIZATION
84
+ # ============================================================================
85
+
86
class VisualQAEngine:
    """
    Inference wrapper around a pretrained BLIP visual-question-answering model.

    The processor and model are loaded once at construction time and used
    as-is (no fine-tuning) to answer free-form questions about an image.
    """

    def __init__(self, model_name: str = "Salesforce/blip-vqa-base"):
        """
        Load the BLIP processor and model, then move the model to a device.

        Args:
            model_name: HuggingFace model identifier
        """
        print(f"Loading model: {model_name}")
        self.processor = BlipProcessor.from_pretrained(model_name)
        self.model = BlipForQuestionAnswering.from_pretrained(model_name)

        # Prefer CUDA when present; fall back to CPU otherwise.
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.model.to(self.device)
        print(f"Model loaded on device: {self.device}")

    def answer_question(self, image: Image.Image, question: str) -> str:
        """
        Answer a question about an image.

        The image and question are encoded by the processor, the model
        generates up to 50 tokens without gradient tracking, and the first
        generated sequence is decoded with special tokens removed.

        Args:
            image: PIL Image object
            question: User's question about the image

        Returns:
            The decoded answer text
        """
        # Encode both modalities as PyTorch tensors on the model's device.
        encoded = self.processor(image, question, return_tensors="pt")
        encoded = encoded.to(self.device)

        # Inference only: skip gradient bookkeeping.
        with torch.no_grad():
            generated = self.model.generate(**encoded, max_length=50)

        first_sequence = generated[0]
        return self.processor.decode(first_sequence, skip_special_tokens=True)
137
+
138
+
139
+ # ============================================================================
140
+ # APPLICATION LOGIC (REST-STYLE PURE FUNCTIONS)
141
+ # ============================================================================
142
+
143
def validate_question(question: str, image: Optional[Image.Image]) -> Tuple[bool, str]:
    """
    Check the preconditions for answering a question.

    The checks run in this order:
    1. An image must be present (``image`` is not None)
    2. The question must contain non-whitespace text

    Args:
        question: User's input question
        image: Stored image (or None)

    Returns:
        (is_valid, error_message); the message is "" when valid
    """
    # A missing image takes precedence over an empty question.
    if image is None:
        return False, "⚠️ Please upload an image first before asking questions."

    has_text = bool(question) and question.strip() != ""
    if has_text:
        return True, ""

    return False, "⚠️ Please enter a question."
165
+
166
+
167
def process_question(
    vqa_engine: VisualQAEngine,
    session_memory: SessionMemory,
    session_id: str,
    question: str
) -> Tuple[str, List[Tuple[str, str]]]:
    """
    Answer one user question against the session's stored image.

    Flow:
    1. Fetch the session's image
    2. Validate the image/question pair
    3. On success, run the vision-language model and record the Q&A pair
    4. Return the reply together with the session's history

    When validation fails, the validation message is returned in place of
    an answer and nothing is added to the history.

    Args:
        vqa_engine: Visual QA inference engine
        session_memory: Session storage
        session_id: Current session identifier
        question: User's question

    Returns:
        (answer_or_error_message, conversation_history)
    """
    stored_image = session_memory.get_image(session_id)

    ok, problem = validate_question(question, stored_image)
    if ok:
        # Only valid questions reach the model and the history.
        reply = vqa_engine.answer_question(stored_image, question)
        session_memory.add_to_history(session_id, question, reply)
    else:
        reply = problem

    return reply, session_memory.get_history(session_id)
208
+
209
+
210
def handle_image_upload(
    session_memory: SessionMemory,
    session_id: str,
    image: Image.Image
) -> str:
    """
    Store an uploaded image in session memory and report the outcome.

    Args:
        session_memory: Session storage
        session_id: Current session identifier
        image: Uploaded PIL Image, or None when nothing was uploaded

    Returns:
        A user-facing status message: a warning when ``image`` is None,
        otherwise a success confirmation.
    """
    # Nothing to store; session memory is not touched in this case.
    if image is None:
        return "⚠️ No image uploaded."

    session_memory.store_image(session_id, image)

    # The check-mark emoji was mis-encoded ("βœ…") in the previous text.
    return "✅ Image uploaded successfully! You can now ask questions about this image."
233
+
234
+
235
def reset_conversation(
    session_memory: SessionMemory,
    session_id: str
) -> Tuple[str, List, None]:
    """
    Reset the conversation by clearing the session's image and history.

    Args:
        session_memory: Session storage
        session_id: Current session identifier

    Returns:
        (status_message, empty_history, None_for_image) - values suitable
        for clearing the status box, chat display and image widget.
    """
    session_memory.reset_session(session_id)
    # The "refresh" emoji was mis-encoded ("πŸ”„") in the previous text.
    return "🔄 Conversation reset. Please upload a new image.", [], None
251
+
252
+
253
+ # ============================================================================
254
+ # GRADIO UI INTERFACE
255
+ # ============================================================================
256
+
257
def create_gradio_interface(vqa_engine: VisualQAEngine, session_memory: SessionMemory) -> gr.Blocks:
    """
    Build the Gradio UI for the Visual Conversational Intelligence Engine.

    UI Components:
    - Image upload (with explicit upload button and status box)
    - Question input
    - Chat history display
    - Reset button

    Session handling:
        The session ID lives in a ``gr.State`` that starts as None, and a
        session is created lazily the first time a visitor triggers an
        event. Creating the session while building the layout would bake a
        single ID into the State's default, so every visitor would share
        the same image and conversation history.

    Validation feedback:
        When a question cannot be answered (no image, empty question), the
        validation message is shown in the chat display as a transient
        turn. It is not written to the session history.

    Args:
        vqa_engine: Visual QA inference engine
        session_memory: Session storage shared by all handlers

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app
    """

    def ensure_session(current_id):
        """Return ``current_id`` if set, otherwise create a new session."""
        if current_id:
            return current_id
        return session_memory.create_session()

    with gr.Blocks(title="Visual Conversational Intelligence Engine") as demo:
        # Per-visitor session ID (hidden); populated on first interaction.
        session_id = gr.State(value=None)

        # Header
        # NOTE(review): the leading emoji was truncated in the recovered
        # text ("πŸ”"); the magnifying glass is an assumption - confirm.
        gr.Markdown("""
        # 🔍 Visual Conversational Intelligence Engine

        **An image-grounded multi-turn conversational system**

        ### How to use:
        1. **Upload an image** (required)
        2. **Ask questions** about the image
        3. **Continue the conversation** - ask follow-up questions without re-uploading
        4. **Reset** to start over with a new image

        ### Important:
        - All answers are strictly grounded in the uploaded image
        - Questions unrelated to the image will be politely declined
        - The system uses BLIP (Vision-Language Model) for inference
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Image upload section
                gr.Markdown("### 📤 Step 1: Upload Image")
                image_input = gr.Image(
                    type="pil",
                    label="Upload an image to analyze",
                    height=300
                )
                upload_status = gr.Textbox(
                    label="Upload Status",
                    interactive=False,
                    lines=1
                )

                # Upload button
                upload_btn = gr.Button("📥 Upload Image", variant="primary")

            with gr.Column(scale=1):
                # Question and conversation section
                gr.Markdown("### 💬 Step 2: Ask Questions")
                chatbot = gr.Chatbot(
                    label="Conversation History",
                    height=300
                )
                question_input = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask a question about the uploaded image...",
                    lines=2
                )

                with gr.Row():
                    submit_btn = gr.Button("🚀 Ask Question", variant="primary")
                    reset_btn = gr.Button("🔄 Reset Conversation", variant="secondary")

        # Event handlers
        # Every handler also returns the session ID so that a lazily
        # created session is written back into this visitor's State.

        def upload_image_handler(image, current_session):
            """Store the uploaded image in the visitor's session."""
            current_session = ensure_session(current_session)
            status = handle_image_upload(session_memory, current_session, image)
            return status, current_session

        def ask_question_handler(question, current_session):
            """Answer a question, or surface the validation error in the chat."""
            current_session = ensure_session(current_session)

            stored_image = session_memory.get_image(current_session)
            is_valid, error_msg = validate_question(question, stored_image)
            if not is_valid:
                # Show the error as a transient turn on a copy of the
                # history; keep the typed question so it is not lost.
                transcript = list(session_memory.get_history(current_session))
                shown_question = question.strip() if question else ""
                transcript.append((shown_question or None, error_msg))
                return transcript, question, current_session

            _, history = process_question(
                vqa_engine, session_memory, current_session, question
            )
            # Return updated history and clear the input box.
            return list(history), "", current_session

        def reset_handler(current_session):
            """Clear the visitor's image and conversation history."""
            current_session = ensure_session(current_session)
            status, history, image = reset_conversation(session_memory, current_session)
            return status, history, image, current_session

        # Wire up events
        upload_btn.click(
            fn=upload_image_handler,
            inputs=[image_input, session_id],
            outputs=[upload_status, session_id]
        )

        submit_btn.click(
            fn=ask_question_handler,
            inputs=[question_input, session_id],
            outputs=[chatbot, question_input, session_id]
        )

        question_input.submit(
            fn=ask_question_handler,
            inputs=[question_input, session_id],
            outputs=[chatbot, question_input, session_id]
        )

        reset_btn.click(
            fn=reset_handler,
            inputs=[session_id],
            outputs=[upload_status, chatbot, image_input, session_id]
        )

        # Footer
        gr.Markdown("""
        ---
        **Academic Prototype** | Demonstrates AI system design for visual question answering

        **Tech Stack:** Python • HuggingFace BLIP • Gradio • Session-based Memory
        """)

    return demo
376
+
377
+
378
+ # ============================================================================
379
+ # MAIN APPLICATION ENTRY POINT
380
+ # ============================================================================
381
+
382
def main():
    """
    Initialize and launch the Visual Conversational Intelligence Engine.

    Steps, each announced on stdout:
    1. Load the BLIP vision-language model
    2. Create the in-memory session store
    3. Build the Gradio interface
    4. Launch the web server on 0.0.0.0:7860 (no public share link)
    """
    banner = "=" * 60
    print(banner)
    print("VISUAL CONVERSATIONAL INTELLIGENCE ENGINE")
    print(banner)

    # Initialize core components
    print("\n[1/3] Initializing Vision-Language Model...")
    vqa_engine = VisualQAEngine(model_name="Salesforce/blip-vqa-base")

    print("\n[2/3] Setting up session memory...")
    session_memory = SessionMemory()

    print("\n[3/3] Creating Gradio interface...")
    demo = create_gradio_interface(vqa_engine, session_memory)

    print("\n" + banner)
    # The rocket emoji was mis-encoded ("πŸš€") in the previous text.
    print("🚀 Launching application...")
    print(banner)

    # Launch the application
    demo.launch(
        share=False,  # Set to True for public sharing
        server_name="0.0.0.0",  # Allow external access
        server_port=7860  # Standard Gradio port
    )


if __name__ == "__main__":
    main()