abhiman181025 commited on
Commit
1314bf5
Β·
1 Parent(s): 14ebc7f

First commit

Browse files
README copy.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Prompt Pilot
3
+ emoji: πŸ“Š
4
+ colorFrom: gray
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 5.34.2
8
+ app_file: app.py
9
+ pinned: false
10
+ short_description: Prompt engineering for InternVL3 image analysis
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
InternVL3 Prompt Engineering Application
Entry point for the modular InternVL3 image analysis application.
"""

from frontend.gradio_app import GradioApp


def main() -> None:
    """Build the Gradio application and start serving it."""
    application = GradioApp()
    application.launch()


if __name__ == "__main__":
    main()
backend/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Public API of the ``backend`` package.

Re-exports the configuration, model, inference, and utility entry points so
callers can import everything from ``backend`` directly.
"""

from .config import ConfigManager
# QwenModel is exported by backend.models (see backend/models/__init__.py);
# re-export it here as well so the package root exposes every model class.
from .models import ModelManager, InternVLModel, BaseModel, QwenModel
from .inference import InferenceEngine
from .utils import (
    build_transform,
    load_image,
    extract_file_dict,
    validate_data,
    extract_binary_output,
    create_accuracy_table,
    save_dataframe_to_csv,
)

__all__ = [
    'ConfigManager',
    'ModelManager',
    'InternVLModel',
    'BaseModel',
    'QwenModel',
    'InferenceEngine',
    'build_transform',
    'load_image',
    'extract_file_dict',
    'validate_data',
    'extract_binary_output',
    'create_accuracy_table',
    'save_dataframe_to_csv',
]
backend/config/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
"""Configuration subpackage: exposes :class:`ConfigManager`."""

from .config_manager import ConfigManager

__all__ = ['ConfigManager']
backend/config/config_manager.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import yaml
import os
from pathlib import Path
from typing import Dict, List, Any, Optional

class ConfigManager:
    """Manages configuration loading and access for the application."""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the configuration manager.

        Args:
            config_path: Path to the configuration file. If None, uses default path.
        """
        if config_path is None:
            # Default to config/models.yaml relative to project root
            project_root = Path(__file__).parent.parent.parent
            config_path = project_root / "config" / "models.yaml"

        self.config_path = Path(config_path)
        self._config: Optional[Dict[str, Any]] = None
        self.load_config()

    def load_config(self) -> None:
        """Load configuration from YAML file.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file contains invalid YAML.
        """
        try:
            with open(self.config_path, 'r', encoding='utf-8') as file:
                # An empty YAML file yields None from safe_load; normalize to {}
                # so accessors like self.config.get(...) never hit AttributeError.
                self._config = yaml.safe_load(file) or {}
            print(f"βœ… Configuration loaded from {self.config_path}")
        except FileNotFoundError:
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")
        except yaml.YAMLError as e:
            raise ValueError(f"Invalid YAML in configuration file: {e}")

    def reload_config(self) -> None:
        """Reload configuration from file."""
        self.load_config()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the full configuration dictionary (lazy-loads if needed)."""
        if self._config is None:
            self.load_config()
        return self._config

    def get_available_models(self) -> Dict[str, str]:
        """Get a dictionary of available model names and their IDs."""
        models = self.config.get('models', {})
        return {name: model_config['model_id'] for name, model_config in models.items()}

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Get configuration for a specific model.

        Args:
            model_name: Name of the model (e.g., 'InternVL3-8B')

        Returns:
            Model configuration dictionary

        Raises:
            KeyError: If model name is not found
        """
        models = self.config.get('models', {})
        if model_name not in models:
            available = list(models.keys())
            raise KeyError(f"Model '{model_name}' not found. Available models: {available}")

        return models[model_name]

    def get_supported_quantizations(self, model_name: str) -> List[str]:
        """Get supported quantization methods for a model."""
        model_config = self.get_model_config(model_name)
        return model_config.get('supported_quantizations', [])

    def get_default_quantization(self, model_name: str) -> str:
        """Get the default quantization method for a model."""
        model_config = self.get_model_config(model_name)
        return model_config.get('default_quantization', 'non-quantized(fp16)')

    def get_default_model(self) -> str:
        """Get the default model name."""
        return self.config.get('default_model', 'InternVL3-8B')

    def validate_model_and_quantization(self, model_name: str, quantization: str) -> bool:
        """
        Validate if a quantization method is supported for a model.

        Args:
            model_name: Name of the model
            quantization: Quantization method

        Returns:
            True if valid, False otherwise (including when the model is unknown)
        """
        try:
            supported = self.get_supported_quantizations(model_name)
            return quantization in supported
        except KeyError:
            # Unknown model names are treated as "not valid" rather than an error.
            return False

    def apply_environment_settings(self) -> None:
        """Apply environment settings to the current process."""
        # Set CUDA memory allocation strategy to reduce fragmentation
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    def get_model_description(self, model_name: str) -> str:
        """Get description for a model."""
        model_config = self.get_model_config(model_name)
        return model_config.get('description', 'No description available')

    def __str__(self) -> str:
        """String representation of the configuration manager."""
        return f"ConfigManager(config_path={self.config_path})"

    def __repr__(self) -> str:
        """Detailed string representation."""
        models = list(self.get_available_models().keys())
        return f"ConfigManager(config_path={self.config_path}, models={models})"
backend/inference/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
"""Inference subpackage: exposes :class:`InferenceEngine`."""

from .inference_engine import InferenceEngine

__all__ = ['InferenceEngine']
backend/inference/inference_engine.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pandas as pd
import threading
import time
import os
from pathlib import Path
from typing import Dict, List, Tuple, Union, Any, Optional, Callable
import gradio as gr
from ..models.model_manager import ModelManager
from ..utils.data_processing import extract_file_dict, validate_data, extract_binary_output
from ..config.config_manager import ConfigManager
from ..utils.metrics import create_accuracy_table
from datetime import datetime
import boto3


class InferenceEngine:
    """Engine for handling batch inference and processing control."""

    def __init__(self, model_manager: ModelManager, config_manager: ConfigManager):
        """
        Initialize the inference engine.

        Args:
            model_manager: Model manager instance
            config_manager: Configuration manager instance
        """
        self.model_manager = model_manager
        self.config_manager = config_manager
        self.processing_lock = threading.Lock()
        self.stop_processing = False
        self.full_df = None  # Store full dataframe with image paths

    def set_stop_flag(self) -> str:
        """Set the global stop flag to interrupt processing."""
        with self.processing_lock:
            self.stop_processing = True
            print("πŸ›‘ Stop signal received. Processing will halt after current image...")
        return "πŸ›‘ Stopping process... Please wait for current image to complete."

    def reset_stop_flag(self) -> None:
        """Reset the global stop flag before starting new processing."""
        with self.processing_lock:
            self.stop_processing = False

    def check_stop_flag(self) -> bool:
        """Check if processing should be stopped."""
        with self.processing_lock:
            return self.stop_processing

    def _should_load_model(self, model_selection: str, quantization_type: str) -> bool:
        """
        Check if we need to load the model.

        Args:
            model_selection: Selected model name
            quantization_type: Selected quantization type

        Returns:
            True if model needs to be loaded, False otherwise
        """
        # If no model is loaded, we need to load
        if not self.model_manager.current_model or not self.model_manager.current_model.is_model_loaded():
            return True

        # If different model is selected, we need to load
        if self.model_manager.current_model_name != model_selection:
            return True

        # If same model but different quantization, we need to reload
        if self.model_manager.current_model.current_quantization != quantization_type:
            return True

        return False

    # NOTE: annotation fixed from `gr.Progress()` (an instance created at
    # definition time) to the type `gr.Progress`.
    def _ensure_correct_model_loaded(self, model_selection: str, quantization_type: str, progress: gr.Progress) -> None:
        """
        Ensure the correct model with correct quantization is loaded.

        Args:
            model_selection: Selected model name
            quantization_type: Selected quantization type
            progress: Gradio progress object

        Raises:
            Exception: If the model fails to load.
        """
        if self._should_load_model(model_selection, quantization_type):
            progress(0, desc=f"πŸš€ Loading {model_selection} ({quantization_type})...")
            print(f"πŸš€ Loading {model_selection} with {quantization_type}...")
            success = self.model_manager.load_model(model_selection, quantization_type)
            if not success:
                raise Exception(f"Failed to load model {model_selection} with {quantization_type}")
        else:
            print(f"βœ… Correct model already loaded: {model_selection} with {quantization_type}")

    def process_folder_input(
        self,
        folder_path: List[Path],
        prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress
    ) -> Tuple[Any, ...]:
        """
        Process input folder with images and optional CSV.

        Args:
            folder_path: List of Path objects from Gradio
            prompt: Text prompt for inference
            quantization_type: Model quantization type
            model_selection: Selected model name
            progress: Gradio progress object

        Returns:
            Tuple of UI update states and results
        """
        # Reset stop flag at the beginning of processing
        self.reset_stop_flag()

        # Extract file dictionary
        file_dict = extract_file_dict(folder_path)

        # Print all file names for debug
        for fname in file_dict:
            print(fname)

        validation_result, message = validate_data(file_dict)

        # Handle different validation results
        if validation_result is False:
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), message, gr.update(visible=False), ""
        elif validation_result in ["no_csv", "multiple_csv"]:
            return self._process_without_csv(file_dict, prompt, quantization_type, model_selection, progress)
        else:
            return self._process_with_csv(file_dict, prompt, quantization_type, model_selection, progress)

    def _process_without_csv(
        self,
        file_dict: Dict[str, Path],
        prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress
    ) -> Tuple[Any, ...]:
        """Process images without CSV file (no ground truth available)."""
        image_exts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']
        image_file_dict = {fname: file_dict[fname] for fname in file_dict
                           if any(fname.lower().endswith(ext) for ext in image_exts)}

        filtered_rows = []
        total_images = len(image_file_dict)

        if total_images == 0:
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), "No image files found.", gr.update(visible=False), ""

        # Ensure correct model is loaded
        self._ensure_correct_model_loaded(model_selection, quantization_type, progress)

        # Initialize progress
        progress(0, desc=f"πŸš€ Starting to process {total_images} images...")
        print(f"Starting to process {total_images} images with {model_selection}...")

        for idx, (img_name, img_path) in enumerate(image_file_dict.items()):
            # Check stop flag before processing each image
            if self.check_stop_flag():
                print(f"πŸ›‘ Processing stopped by user at image {idx + 1}/{total_images}")
                # Add remaining images as "Not processed" entries
                for remaining_idx, (remaining_name, remaining_path) in enumerate(list(image_file_dict.items())[idx:]):
                    filtered_rows.append({
                        'S.No': idx + remaining_idx + 1,
                        'Image Name': remaining_name,
                        'Ground Truth': '',
                        'Binary Output': 'Not processed (stopped)',
                        'Model Output': 'Processing stopped by user',
                        'Image Path': str(remaining_path)
                    })

                display_df = pd.DataFrame(filtered_rows)[['S.No', 'Image Name', 'Ground Truth', 'Binary Output', 'Model Output']]
                self.full_df = pd.DataFrame(filtered_rows)
                final_message = f"πŸ›‘ Processing stopped by user. Completed {idx}/{total_images} images."
                print(final_message)
                return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), display_df, gr.update(visible=False), final_message

            try:
                # Update progress with current image info
                current_progress = idx / total_images
                progress_msg = f"πŸ”„ Processing image {idx + 1}/{total_images}: {img_name[:30]}..." if len(img_name) > 30 else f"πŸ”„ Processing image {idx + 1}/{total_images}: {img_name}"
                progress(current_progress, desc=progress_msg)
                print(progress_msg)

                # Use model inference
                model_output = self.model_manager.inference(str(img_path), prompt) if prompt else "No prompt provided"

                # Extract binary output (no ground truth available for file-based processing)
                binary_output = extract_binary_output(model_output, "", [])

                filtered_rows.append({
                    'S.No': idx + 1,
                    'Image Name': img_name,
                    'Ground Truth': '',  # Empty for manual input
                    'Binary Output': binary_output,
                    'Model Output': model_output,
                    'Image Path': str(img_path)
                })

                # Update progress after successful processing
                current_progress = (idx + 1) / total_images
                progress_msg = f"βœ… Completed {idx + 1}/{total_images} images"
                progress(current_progress, desc=progress_msg)
                print(f"Successfully processed image {idx + 1} of {total_images}")

            except Exception as e:
                print(f"Error processing image {idx + 1} of {total_images}: {str(e)}")
                filtered_rows.append({
                    'S.No': idx + 1,
                    'Image Name': img_name,
                    'Ground Truth': '',
                    'Binary Output': 'Enter the output manually',  # Default for errors
                    'Model Output': f"Error: {str(e)}",
                    'Image Path': str(img_path)
                })

                # Update progress even for errors
                current_progress = (idx + 1) / total_images
                progress_msg = f"⚠️ Processed {idx + 1}/{total_images} images (with errors)"
                progress(current_progress, desc=progress_msg)

        # Check if processing was completed or stopped
        if self.check_stop_flag():
            final_message = f"πŸ›‘ Processing stopped by user. Completed {len(filtered_rows)}/{total_images} images."
        else:
            final_message = f"πŸŽ‰ Successfully completed processing all {total_images} images!"

        display_df = pd.DataFrame(filtered_rows)[['S.No', 'Image Name', 'Ground Truth', 'Binary Output', 'Model Output']]
        # Save the full dataframe (with Image Path) for preview
        self.full_df = pd.DataFrame(filtered_rows)
        self.save_results_to_s3(display_df)

        print(final_message)

        # Make the table editable for ground truth input
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), display_df, gr.update(visible=False), final_message

    def _process_with_csv(
        self,
        file_dict: Dict[str, Path],
        prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress
    ) -> Tuple[Any, ...]:
        """Process images with CSV file supplying 'Image Name' / 'Ground Truth' columns."""
        csv_files = [fname for fname in file_dict if fname.lower().endswith('.csv')]
        csv_file = file_dict[csv_files[0]]
        df = pd.read_csv(csv_file)

        # Collect all ground truth values for unique keyword extraction
        all_ground_truths = [str(row['Ground Truth']) for idx, row in df.iterrows()
                             if pd.notna(row['Ground Truth']) and str(row['Ground Truth']).strip()]

        # Find image files
        image_exts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']
        image_file_dict = {fname: file_dict[fname] for fname in file_dict
                           if any(fname.lower().endswith(ext) for ext in image_exts)}

        # Only keep rows where image file exists
        filtered_rows = []
        matching_images = [row for idx, row in df.iterrows() if row['Image Name'] in image_file_dict]
        total_images = len(matching_images)

        if total_images == 0:
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), "No matching images found for entries in CSV.", gr.update(visible=False), ""

        # Ensure correct model is loaded
        self._ensure_correct_model_loaded(model_selection, quantization_type, progress)

        # Initialize progress
        progress(0, desc=f"πŸš€ Starting to process {total_images} images...")
        print(f"Starting to process {total_images} images with {model_selection}...")
        processed_count = 0

        for idx, row in df.iterrows():
            img_name = row['Image Name']
            if img_name in image_file_dict:
                # Check stop flag before processing each image
                if self.check_stop_flag():
                    print(f"πŸ›‘ Processing stopped by user at image {processed_count + 1}/{total_images}")
                    # Add remaining unprocessed images.
                    # NOTE: idx is the iterrows() label used positionally here;
                    # valid because pd.read_csv yields a default RangeIndex.
                    for remaining_idx, remaining_row in df.iloc[idx:].iterrows():
                        if remaining_row['Image Name'] in image_file_dict:
                            filtered_rows.append({
                                'S.No': len(filtered_rows) + 1,
                                'Image Name': remaining_row['Image Name'],
                                'Ground Truth': remaining_row['Ground Truth'],
                                'Binary Output': 'Not processed (stopped)',
                                'Model Output': 'Processing stopped by user',
                                'Image Path': str(image_file_dict[remaining_row['Image Name']])
                            })

                    display_df = pd.DataFrame(filtered_rows)[['S.No', 'Image Name', 'Ground Truth', 'Binary Output', 'Model Output']]
                    self.full_df = pd.DataFrame(filtered_rows)
                    final_message = f"πŸ›‘ Processing stopped by user. Completed {processed_count}/{total_images} images."
                    print(final_message)
                    return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), display_df, gr.update(visible=False), final_message

                try:
                    processed_count += 1
                    # Update progress with current image info
                    current_progress = (processed_count - 1) / total_images
                    progress_msg = f"πŸ”„ Processing image {processed_count}/{total_images}: {img_name[:30]}..." if len(img_name) > 30 else f"πŸ”„ Processing image {processed_count}/{total_images}: {img_name}"
                    progress(current_progress, desc=progress_msg)
                    print(progress_msg)

                    # Use model inference
                    model_output = self.model_manager.inference(str(image_file_dict[img_name]), prompt)

                    # Extract binary output using ground truth and all ground truths for keyword extraction
                    ground_truth = str(row['Ground Truth']) if pd.notna(row['Ground Truth']) else ""
                    binary_output = extract_binary_output(model_output, ground_truth, all_ground_truths)

                    filtered_rows.append({
                        'S.No': len(filtered_rows) + 1,
                        'Image Name': img_name,
                        'Ground Truth': row['Ground Truth'],
                        'Binary Output': binary_output,
                        'Model Output': model_output,
                        'Image Path': str(image_file_dict[img_name])
                    })

                    # Update progress after successful processing
                    current_progress = processed_count / total_images
                    progress_msg = f"βœ… Completed {processed_count}/{total_images} images"
                    progress(current_progress, desc=progress_msg)
                    print(f"Successfully processed image {processed_count} of {total_images}")

                except Exception as e:
                    print(f"Error processing image {processed_count} of {total_images}: {str(e)}")
                    filtered_rows.append({
                        'S.No': len(filtered_rows) + 1,
                        'Image Name': img_name,
                        'Ground Truth': row['Ground Truth'],
                        'Binary Output': 'Enter the output manually',  # Default for errors
                        'Model Output': f"Error: {str(e)}",
                        'Image Path': str(image_file_dict[img_name])
                    })

                    # Update progress even for errors
                    current_progress = processed_count / total_images
                    progress_msg = f"⚠️ Processed {processed_count}/{total_images} images (with errors)"
                    progress(current_progress, desc=progress_msg)

        # Check if processing was completed or stopped
        if self.check_stop_flag():
            final_message = f"πŸ›‘ Processing stopped by user. Completed {len([r for r in filtered_rows if 'stopped' not in r['Model Output']])}/{total_images} images."
        else:
            final_message = f"πŸŽ‰ Successfully completed processing all {total_images} images!"

        display_df = pd.DataFrame(filtered_rows)[['S.No', 'Image Name', 'Ground Truth', 'Binary Output', 'Model Output']]
        # Save the full dataframe (with Image Path) for preview
        self.full_df = pd.DataFrame(filtered_rows)

        self.save_results_to_s3(display_df)

        print(final_message)

        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), display_df, gr.update(visible=False), final_message

    def rerun_with_new_prompt(
        self,
        df: pd.DataFrame,
        new_prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress
    ) -> Tuple[Any, ...]:
        """Rerun processing with new prompt and clear accuracy data."""
        if df is None or not new_prompt.strip():
            return df, None, None, None, gr.update(visible=False), gr.update(visible=False), "⚠️ Please provide a valid prompt"

        # Reset stop flag at the beginning of reprocessing
        self.reset_stop_flag()

        updated_df = df.copy()
        total_images = len(updated_df)

        # Collect all ground truth values for unique keyword extraction
        all_ground_truths = [str(row['Ground Truth']) for idx, row in updated_df.iterrows()
                             if pd.notna(row['Ground Truth']) and str(row['Ground Truth']).strip()]

        # Get the full dataframe with image paths
        if self.full_df is None:
            return df, None, None, None, gr.update(visible=False), gr.update(visible=False), "⚠️ No image data available"

        # Create a copy of the full dataframe to update
        updated_full_df = self.full_df.copy()

        # Ensure correct model is loaded
        self._ensure_correct_model_loaded(model_selection, quantization_type, progress)

        # Initialize progress
        progress(0, desc=f"πŸš€ Starting to reprocess {total_images} images with new prompt...")
        print(f"πŸš€ Starting to reprocess {total_images} images with new prompt...")

        for i in range(len(updated_df)):
            # Check stop flag before processing each image
            if self.check_stop_flag():
                print(f"πŸ›‘ Reprocessing stopped by user at image {i + 1}/{total_images}")
                # Mark remaining images as not reprocessed in both dataframes
                for j in range(i, len(updated_df)):
                    updated_df.iloc[j, updated_df.columns.get_loc("Model Output")] = "Reprocessing stopped by user"
                    updated_df.iloc[j, updated_df.columns.get_loc("Binary Output")] = "Not reprocessed (stopped)"
                    # Also update the full dataframe
                    if j < len(updated_full_df):
                        updated_full_df.iloc[j, updated_full_df.columns.get_loc("Model Output")] = "Reprocessing stopped by user"
                        updated_full_df.iloc[j, updated_full_df.columns.get_loc("Binary Output")] = "Not reprocessed (stopped)"

                # Update the full_df reference
                self.full_df = updated_full_df

                final_message = f"πŸ›‘ Reprocessing stopped by user. Completed {i}/{total_images} images."
                print(final_message)
                return updated_df, None, None, None, gr.update(visible=False), gr.update(visible=False), final_message

            try:
                # Get image path from full_df
                image_path = self.full_df.iloc[i]['Image Path']
                image_name = updated_df.iloc[i]['Image Name']
                ground_truth = str(updated_df.iloc[i]['Ground Truth']) if pd.notna(updated_df.iloc[i]['Ground Truth']) else ""

                # Update progress with current image info
                current_progress = i / total_images
                progress_msg = f"πŸ”„ Reprocessing image {i + 1}/{total_images}: {image_name[:30]}..." if len(image_name) > 30 else f"πŸ”„ Reprocessing image {i + 1}/{total_images}: {image_name}"
                progress(current_progress, desc=progress_msg)
                print(progress_msg)

                # Use model inference with new prompt
                model_output = self.model_manager.inference(image_path, new_prompt)

                # Update both the display dataframe and the full dataframe
                updated_df.iloc[i, updated_df.columns.get_loc("Model Output")] = model_output
                updated_full_df.iloc[i, updated_full_df.columns.get_loc("Model Output")] = model_output

                # Extract binary output using ground truth and all ground truths for keyword extraction
                binary_output = extract_binary_output(model_output, ground_truth, all_ground_truths)
                updated_df.iloc[i, updated_df.columns.get_loc("Binary Output")] = binary_output
                updated_full_df.iloc[i, updated_full_df.columns.get_loc("Binary Output")] = binary_output

                # Update progress after successful processing
                current_progress = (i + 1) / total_images
                progress_msg = f"βœ… Completed {i + 1}/{total_images} images"
                progress(current_progress, desc=progress_msg)
                print(f"βœ… Successfully reprocessed image {i + 1}/{total_images}")

            except Exception as e:
                print(f"❌ Error reprocessing image {i + 1}/{total_images}: {str(e)}")
                error_message = f"Error: {str(e)}"

                # Update both dataframes with error information
                updated_df.iloc[i, updated_df.columns.get_loc("Model Output")] = error_message
                updated_df.iloc[i, updated_df.columns.get_loc("Binary Output")] = "Enter the output manually"
                updated_full_df.iloc[i, updated_full_df.columns.get_loc("Model Output")] = error_message
                updated_full_df.iloc[i, updated_full_df.columns.get_loc("Binary Output")] = "Enter the output manually"

                # Update progress even for errors
                current_progress = (i + 1) / total_images
                progress_msg = f"⚠️ Processed {i + 1}/{total_images} images (with errors)"
                progress(current_progress, desc=progress_msg)

        # Update the full_df reference with the updated data
        self.full_df = updated_full_df

        # Check if reprocessing was completed or stopped
        if self.check_stop_flag():
            final_message = f"πŸ›‘ Reprocessing stopped by user. Completed reprocessing for some images."
        else:
            final_message = f"πŸŽ‰ Successfully completed reprocessing all {total_images} images with new prompt! Click 'Generate Metrics' to see accuracy data."
            self.save_results_to_s3(updated_full_df)

        print(final_message)

        # Return updated dataframe and clear accuracy data (hide section 3)
        return updated_df, None, None, None, gr.update(visible=False), gr.update(visible=False), final_message

    def save_results_to_s3(self, df) -> None:
        """Save results (CSV and evaluation metrics) to the configured S3 bucket.

        Best-effort: any failure is logged, never raised to the caller.
        """
        # Compute names OUTSIDE the try so the except-path fallback below can
        # use them; previously a failure in create_accuracy_table raised a
        # NameError on csv_file_name/s3_bucket/s3_path inside the handler.
        s3_bucket = os.getenv('AWS_BUCKET')
        prefix = os.getenv('AWS_PREFIX')
        s3_path = f"{prefix}/{datetime.now().date()}"
        date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        csv_file_name = f'{date_time}_model_output.csv'
        text_file_name = f'{date_time}_evaluation_metrics.txt'
        try:
            # create accuracy table
            metrics_df, _, cm_values = create_accuracy_table(df)

            # save metrics_df to text file
            with open(text_file_name, 'w') as f:
                f.write(metrics_df.to_string() + '\n\n')
                f.write(cm_values.to_string())

            # save df to csv
            df.to_csv(csv_file_name, index=False)

            # upload files to s3
            status = self.upload_file(text_file_name, s3_bucket, f"{s3_path}/{text_file_name}")
            print(f"Status of uploading {text_file_name} to {s3_bucket}/{s3_path}/{text_file_name}: {status}")
            status = self.upload_file(csv_file_name, s3_bucket, f"{s3_path}/{csv_file_name}")
            print(f"Status of uploading {csv_file_name} to {s3_bucket}/{s3_path}/{csv_file_name}: {status}")

            # delete files from local
            os.remove(text_file_name)
            os.remove(csv_file_name)
            print(f"Deleted {text_file_name} and {csv_file_name}")
        except Exception as e:
            print(f"Error saving results to s3: {e}")
            # Metrics can legitimately fail (e.g. not enough classes); still
            # upload the raw results CSV in that case.
            if "No valid data" in str(e) or "Need at least 2 different" in str(e):
                df.to_csv(csv_file_name, index=False)
                status = self.upload_file(csv_file_name, s3_bucket, f"{s3_path}/{csv_file_name}")
                print(f"Status of uploading only csv file to {s3_bucket}/{s3_path}/{csv_file_name}: {status}")
                os.remove(csv_file_name)
                print(f"Deleted {csv_file_name}")

    def upload_file(self, file_name, bucket, object_name=None) -> bool:
        """Upload a file to an S3 bucket.

        :param file_name: File to upload
        :param bucket: Bucket to upload to
        :param object_name: S3 object name. If not specified then file_name is used
        :return: True if file was uploaded, else False
        """
        access_key = os.getenv('AWS_ACCESS_KEY_ID')
        secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')
        # If S3 object_name was not specified, use file_name
        if object_name is None:
            object_name = os.path.basename(file_name)

        # Upload the file
        s3_client = boto3.client('s3', aws_access_key_id=access_key, aws_secret_access_key=secret_key)
        try:
            s3_client.upload_file(file_name, bucket, object_name)
        except Exception as e:
            print(f"Error uploading {file_name} to s3: {e}")
            return False
        return True
backend/models/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
"""Model implementations exposed by the ``backend.models`` package."""

from .base_model import BaseModel
from .model_manager import ModelManager
from .internvl import InternVLModel
from .qwen import QwenModel

__all__ = ['BaseModel', 'ModelManager', 'InternVLModel', 'QwenModel']
backend/models/base_model.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Any, Optional, List
3
+ import torch
4
+ from transformers import AutoModel, AutoTokenizer
5
+
6
class BaseModel(ABC):
    """Abstract base class for all vision-language models."""

    def __init__(self, model_name: str, model_config: Dict[str, Any]):
        """
        Create the wrapper in its unloaded state.

        Args:
            model_name: Name of the model
            model_config: Configuration dictionary for the model; must
                contain a 'model_id' entry
        """
        self.model_name = model_name
        self.model_config = model_config
        self.model_id = model_config['model_id']
        # Populated by load_model(); None while the model is unloaded.
        self.model = None
        self.tokenizer = None
        self.current_quantization = None
        self.is_loaded = False

    @abstractmethod
    def load_model(self, quantization_type: str, **kwargs) -> bool:
        """
        Load the model with specified quantization.

        Args:
            quantization_type: Type of quantization to use
            **kwargs: Additional arguments for model loading

        Returns:
            True if successful, False otherwise
        """

    @abstractmethod
    def unload_model(self) -> None:
        """Unload the model from memory."""

    @abstractmethod
    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference on an image with a text prompt.

        Args:
            image_path: Path to the image file
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters

        Returns:
            Model's text response
        """

    def is_model_loaded(self) -> bool:
        """Check if model is currently loaded."""
        return self.is_loaded

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the model (static config plus load state)."""
        cfg = self.model_config
        info = {
            'name': self.model_name,
            'model_id': self.model_id,
            'description': cfg.get('description', ''),
            'min_gpu_memory_gb': cfg.get('min_gpu_memory_gb', 0),
            'recommended_gpu_memory_gb': cfg.get('recommended_gpu_memory_gb', 0),
            'supported_quantizations': cfg.get('supported_quantizations', []),
            'default_quantization': cfg.get('default_quantization', ''),
            'is_loaded': self.is_loaded,
            'current_quantization': self.current_quantization,
        }
        return info

    def get_supported_quantizations(self) -> List[str]:
        """Get list of supported quantization methods."""
        return self.model_config.get('supported_quantizations', [])

    def get_memory_requirements(self) -> Dict[str, int]:
        """Get memory requirements (minimum and recommended GPU GB)."""
        cfg = self.model_config
        return {
            'min_gpu_memory_gb': cfg.get('min_gpu_memory_gb', 0),
            'recommended_gpu_memory_gb': cfg.get('recommended_gpu_memory_gb', 0),
        }

    def validate_quantization(self, quantization_type: str) -> bool:
        """
        Validate if the quantization type is supported.

        Args:
            quantization_type: Quantization type to validate

        Returns:
            True if supported, False otherwise
        """
        return quantization_type in self.get_supported_quantizations()

    def __str__(self) -> str:
        """One-line summary, e.g. ``name (8bit) - loaded``."""
        pieces = [self.model_name]
        if self.current_quantization:
            pieces.append(f" ({self.current_quantization})")
        pieces.append(" - loaded" if self.is_loaded else " - not loaded")
        return "".join(pieces)

    def __repr__(self) -> str:
        """Detailed string representation."""
        return f"BaseModel(name={self.model_name}, loaded={self.is_loaded}, quantization={self.current_quantization})"
backend/models/internvl/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .internvl_model import InternVLModel
2
+
3
+ __all__ = ['InternVLModel']
backend/models/internvl/internvl_model.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gc
3
+ import os
4
+ from typing import Dict, Any, Optional, Callable
5
+ from transformers import AutoModel, AutoTokenizer, AutoConfig
6
+ from ..base_model import BaseModel
7
+ from ...utils.image_processing import load_image
8
+ from ...config.config_manager import ConfigManager
9
+
10
+
11
class InternVLModel(BaseModel):
    """InternVL3 model implementation."""

    def __init__(self, model_name: str, model_config: Dict[str, Any], config_manager: ConfigManager):
        """
        Initialize the InternVL model.

        Args:
            model_name: Name of the model
            model_config: Configuration dictionary for the model
            config_manager: Configuration manager instance
        """
        super().__init__(model_name, model_config)
        self.config_manager = config_manager

        # Let the CUDA allocator grow segments instead of fragmenting.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    def check_model_exists_locally(self) -> bool:
        """Check if model exists locally in the Hugging Face cache."""
        try:
            from transformers.utils import cached_file
            # local_files_only=True raises when config.json is not cached.
            cached_file(self.model_id, "config.json", local_files_only=True)
            return True
        except Exception:
            # Fixed: was a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit. Exception still covers cache
            # misses and import errors.
            return False

    def download_model_with_progress(self, progress_callback: Optional[Callable] = None) -> bool:
        """
        Download model with progress tracking.

        Args:
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise
        """
        try:
            if progress_callback:
                progress_callback("📥 Downloading tokenizer...")

            # Download tokenizer first (smaller). The call populates the
            # HF cache; the returned object itself is not needed here.
            AutoTokenizer.from_pretrained(
                self.model_id,
                trust_remote_code=True,
                use_fast=False
            )

            if progress_callback:
                progress_callback("📥 Downloading model weights... This may take several minutes...")

            # NOTE(review): this only fetches the model *config*; the
            # weights are actually pulled on the first from_pretrained in
            # load_model(). Confirm whether pre-fetching weights here was
            # intended.
            AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)

            if progress_callback:
                progress_callback("✅ Model downloaded successfully!")

            return True
        except Exception as e:
            if progress_callback:
                progress_callback(f"❌ Download failed: {str(e)}")
            return False

    def split_model(self) -> "Dict[str, int] | str":
        """
        Distribute LLM layers across GPUs, keeping vision encoder on GPU 0.

        Returns:
            Device map dictionary, or the string ``"auto"`` when fewer
            than two GPUs are available (annotation fixed accordingly).
        """
        device_map = {}
        world_size = torch.cuda.device_count()

        if world_size < 2:
            return "auto"  # let transformers decide

        cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
        num_layers = cfg.llm_config.num_hidden_layers  # type: ignore[attr-defined]

        # More aggressive distribution - treat GPU 0 as 0.3 GPU capacity due to vision model
        effective_gpus = world_size - 0.7  # More conservative for GPU 0
        layers_per_gpu = num_layers / effective_gpus

        # Calculate layer distribution
        gpu_layers = []
        for i in range(world_size):
            if i == 0:
                # GPU 0 gets fewer layers due to vision model
                gpu_layers.append(max(1, int(layers_per_gpu * 0.3)))
            else:
                gpu_layers.append(int(layers_per_gpu))

        # Adjust if total doesn't match num_layers
        total_assigned = sum(gpu_layers)
        diff = num_layers - total_assigned
        if diff > 0:
            # Add remaining layers to non-zero GPUs
            for i in range(1, min(world_size, diff + 1)):
                gpu_layers[i] += 1
        elif diff < 0:
            # Remove excess layers from GPU 0
            gpu_layers[0] = max(1, gpu_layers[0] + diff)

        # Assign layers to devices
        layer_cnt = 0
        for gpu_id, num_layers_on_gpu in enumerate(gpu_layers):
            for _ in range(num_layers_on_gpu):
                if layer_cnt < num_layers:
                    device_map[f'language_model.model.layers.{layer_cnt}'] = gpu_id
                    layer_cnt += 1

        # Distribute other components more evenly across GPUs
        last_gpu = world_size - 1

        # Vision model must stay on GPU 0
        device_map['vision_model'] = 0
        device_map['mlp1'] = 0

        # Distribute language model components across GPUs
        device_map['language_model.model.tok_embeddings'] = 0
        device_map['language_model.model.embed_tokens'] = 0
        device_map['language_model.model.norm'] = last_gpu  # Move to last GPU
        device_map['language_model.model.rotary_emb'] = 1 if world_size > 1 else 0  # Move to GPU 1
        device_map['language_model.output'] = last_gpu  # Move to last GPU
        device_map['language_model.lm_head'] = last_gpu  # Move to last GPU

        # Keep the last layer on the same GPU as output layers for compatibility
        device_map[f'language_model.model.layers.{num_layers - 1}'] = last_gpu

        print(f"Layer distribution: {gpu_layers}")
        print(f"Total layers: {num_layers}, Assigned: {sum(gpu_layers)}")

        return device_map

    def load_model(self, quantization_type: str, progress_callback: Optional[Callable] = None) -> bool:
        """
        Load the model with specified quantization.

        Args:
            quantization_type: Type of quantization to use
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise

        Raises:
            ValueError: If the quantization type is unsupported.
            Exception: If download or loading fails (model is unloaded first).
        """
        if not self.validate_quantization(quantization_type):
            raise ValueError(f"Quantization type '{quantization_type}' not supported for {self.model_name}")

        # If model is already loaded with the same quantization, return
        if (self.model is not None and self.tokenizer is not None and
                self.current_quantization == quantization_type):
            if progress_callback:
                progress_callback(f"✅ {self.model_name} already loaded!")
            return True

        print(f"Loading {self.model_name} with {quantization_type} quantization...")
        if progress_callback:
            progress_callback(f"🔄 Loading {self.model_name} with {quantization_type} quantization...")

        try:
            # Check if model exists locally
            model_exists = self.check_model_exists_locally()
            if not model_exists:
                if progress_callback:
                    progress_callback(f"📥 {self.model_name} not found locally. Starting download...")
                print(f"Model {self.model_name} not found locally. Starting download...")
                success = self.download_model_with_progress(progress_callback)
                if not success:
                    raise Exception(f"Failed to download {self.model_name}")
            else:
                if progress_callback:
                    progress_callback(f"✅ {self.model_name} found locally.")

            # Clear existing model if any
            if self.model is not None:
                self.unload_model()

            # Print memory before loading
            self._print_gpu_memory("before loading")

            if progress_callback:
                progress_callback(f"🚀 Loading {self.model_name} tokenizer...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                trust_remote_code=True,
                use_fast=False
            )

            # Load model based on quantization type
            if "non-quantized" in quantization_type:
                if progress_callback:
                    progress_callback(f"🚀 Loading {self.model_name} model in 16-bit precision...")

                device_map = self.split_model()
                print(f"Device map for multi-GPU: {device_map}")

                # Try loading with custom device_map, fallback to "auto" if it fails.
                # Some InternVL models (e.g., InternVL3_5) don't support custom
                # device_map due to missing 'all_tied_weights_keys' attribute.
                try:
                    self.model = AutoModel.from_pretrained(
                        self.model_id,
                        torch_dtype=torch.bfloat16,
                        low_cpu_mem_usage=True,
                        use_flash_attn=True,
                        trust_remote_code=True,
                        device_map=device_map,
                    ).eval()
                except (AttributeError, TypeError, RuntimeError, ValueError) as e:
                    error_str = str(e).lower()
                    # Check for device_map related errors, especially
                    # all_tied_weights_keys — a known issue with some InternVL
                    # models that don't expose the attribute required for a
                    # custom device_map.
                    if ("all_tied_weights_keys" in error_str or
                            "tied_weights" in error_str or
                            ("device_map" in error_str and "attribute" in error_str)):
                        print(f"⚠️ Custom device_map failed ({str(e)}), falling back to 'auto' device_map...")
                        if progress_callback:
                            progress_callback(f"⚠️ Using automatic device mapping...")
                        self.model = AutoModel.from_pretrained(
                            self.model_id,
                            torch_dtype=torch.bfloat16,
                            low_cpu_mem_usage=True,
                            use_flash_attn=True,
                            trust_remote_code=True,
                            device_map="auto",
                        ).eval()
                    else:
                        # Re-raise if it's a different error
                        raise
            else:  # quantized (8bit)
                if progress_callback:
                    progress_callback(f"🚀 Loading {self.model_name} model with 8-bit quantization...")

                print("Loading with 8-bit quantization to reduce memory usage...")
                self.model = AutoModel.from_pretrained(
                    self.model_id,
                    torch_dtype=torch.bfloat16,
                    load_in_8bit=True,
                    low_cpu_mem_usage=True,
                    use_flash_attn=True,
                    trust_remote_code=True,
                    device_map="auto"  # Let transformers handle device mapping for quantized model
                ).eval()

            # Verify model and tokenizer are properly loaded
            if self.model is None:
                raise Exception(f"Model failed to load for {self.model_name}")
            if self.tokenizer is None:
                raise Exception(f"Tokenizer failed to load for {self.model_name}")

            self.current_quantization = quantization_type
            self.is_loaded = True

            success_msg = f"✅ {self.model_name} loaded successfully with {quantization_type} quantization!"
            print(success_msg)
            if progress_callback:
                progress_callback(success_msg)

            # Print GPU memory usage after loading
            self._print_gpu_memory("after loading")

            return True

        except Exception as e:
            error_msg = f"Failed to load model {self.model_name}: {str(e)}"
            print(error_msg)
            if progress_callback:
                progress_callback(f"❌ {error_msg}")

            # Reset on failure
            self.unload_model()
            raise Exception(error_msg)

    def unload_model(self) -> None:
        """Unload the model from memory and release GPU caches."""
        if self.model is not None:
            print("🧹 Clearing model from memory...")
            del self.model
            self.model = None

        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None

        self.current_quantization = None
        self.is_loaded = False

        # Clear GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Force garbage collection
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear again after gc

        print("✅ Model unloaded successfully")

    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference on an image with a text prompt.

        Args:
            image_path: Path to the image file
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters (currently unused)

        Returns:
            Model's text response, or an error message string on failure.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if not self.is_loaded:
            raise RuntimeError(f"Model {self.model_name} is not loaded. Call load_model() first.")

        try:
            # Load and preprocess image using default settings from original app.py
            pixel_values = load_image(image_path, input_size=448, max_num=12).to(torch.bfloat16)

            # Move pixel_values to the same device as the model
            if torch.cuda.is_available():
                # Get the device of the first model parameter
                model_device = next(self.model.parameters()).device
                pixel_values = pixel_values.to(model_device)
            else:
                # Fallback to CPU if no CUDA available
                pixel_values = pixel_values.cpu()

            # Prepare prompt; the <image> placeholder marks where the image goes
            formatted_prompt = f"<image>\n{prompt}" if prompt else "<image>\n"

            # Generation configuration - using same settings as original app.py
            gen_cfg = dict(max_new_tokens=1024, do_sample=True)

            # Perform inference
            response = self.model.chat(self.tokenizer, pixel_values, formatted_prompt, gen_cfg)
            return response

        except Exception as e:
            # Best-effort: callers display the string, so return the error
            # message rather than raising.
            error_msg = f"Error processing image: {str(e)}"
            print(error_msg)
            return error_msg

    def _print_gpu_memory(self, stage: str) -> None:
        """Print per-GPU allocated/reserved memory for debugging."""
        if torch.cuda.is_available():
            print(f"Memory {stage}:")
            for i in range(torch.cuda.device_count()):
                allocated = torch.cuda.memory_allocated(i) / 1024**3
                reserved = torch.cuda.memory_reserved(i) / 1024**3
                print(f"GPU {i}: Allocated {allocated:.2f} GB, Reserved {reserved:.2f} GB")
backend/models/model_manager.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from typing import Dict, Any, Optional, Callable
3
+ from .base_model import BaseModel
4
+ from .internvl import InternVLModel
5
+ from .qwen import QwenModel
6
+ from ..config.config_manager import ConfigManager
7
+
8
+
9
class ModelManager:
    """Manager class for handling multiple vision-language models.

    Keeps a registry of every configured model (instantiated eagerly,
    weights loaded lazily) and guarantees that at most one model holds
    GPU memory at a time by unloading the previous model before loading
    a new one.
    """

    def __init__(self, config_manager: ConfigManager):
        """
        Initialize the model manager.

        Args:
            config_manager: Configuration manager instance
        """
        self.config_manager = config_manager
        self.models: Dict[str, BaseModel] = {}
        # The single model currently resident in memory (if any).
        self.current_model: Optional[BaseModel] = None
        self.current_model_name: Optional[str] = None
        # Serializes load/unload so concurrent requests cannot interleave.
        self.loading_lock = threading.Lock()

        # Apply environment settings
        self.config_manager.apply_environment_settings()

        # Initialize models but don't load them yet
        self._initialize_models()

    def _get_model_class(self, model_config: Dict[str, Any]) -> type:
        """
        Determine the appropriate model class based on model configuration.

        Checks the explicit 'model_type' key first, then falls back to
        substring matching on 'model_id'; unknown types default to InternVL.

        Args:
            model_config: Model configuration dictionary

        Returns:
            Model class to instantiate
        """
        model_type = model_config.get('model_type', 'internvl').lower()
        model_id = model_config.get('model_id', '').lower()

        # Determine model type based on model_id or explicit model_type
        if 'qwen' in model_id or model_type == 'qwen':
            return QwenModel
        elif 'internvl' in model_id or model_type == 'internvl':
            return InternVLModel
        else:
            # Default to InternVL for backward compatibility
            print(f"⚠️ Unknown model type for {model_config.get('name', 'unknown')}, defaulting to InternVL")
            return InternVLModel

    def _initialize_models(self) -> None:
        """Initialize model instances without loading them.

        Populates ``self.models`` with one unloaded wrapper per entry
        returned by the config manager.
        """
        available_models = self.config_manager.get_available_models()

        for model_name, model_id in available_models.items():
            model_config = self.config_manager.get_model_config(model_name)

            # Determine the appropriate model class
            model_class = self._get_model_class(model_config)

            # Create model instance (weights are loaded later on demand)
            self.models[model_name] = model_class(
                model_name=model_name,
                model_config=model_config,
                config_manager=self.config_manager
            )

            print(f"✅ Initialized {model_class.__name__}: {model_name}")

    def get_available_models(self) -> list[str]:
        """Get list of available model names."""
        return list(self.models.keys())

    def get_model_info(self, model_name: str) -> Dict[str, Any]:
        """
        Get information about a specific model.

        Args:
            model_name: Name of the model

        Returns:
            Model information dictionary

        Raises:
            KeyError: If the model name is not registered.
        """
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")

        return self.models[model_name].get_model_info()

    def get_all_models_info(self) -> Dict[str, Dict[str, Any]]:
        """Get information about all available models."""
        return {name: model.get_model_info() for name, model in self.models.items()}

    def load_model(
        self,
        model_name: str,
        quantization_type: str,
        progress_callback: Optional[Callable] = None
    ) -> bool:
        """
        Load a specific model with given quantization.

        The previously loaded model (if different) is unloaded first so
        that only one model occupies GPU memory. The whole operation is
        guarded by ``loading_lock``.

        Args:
            model_name: Name of the model to load
            quantization_type: Type of quantization to use
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise

        Raises:
            KeyError: If the model name is not registered.
        """
        with self.loading_lock:
            if model_name not in self.models:
                raise KeyError(f"Model '{model_name}' not available")

            model = self.models[model_name]

            # Check if this model is already loaded with the same quantization
            if (self.current_model == model and
                model.is_model_loaded() and
                model.current_quantization == quantization_type):
                if progress_callback:
                    progress_callback(f"✅ {model_name} already loaded with {quantization_type}!")
                return True

            # Unload current model if different (must happen before the new
            # load to avoid holding two models in GPU memory at once)
            if (self.current_model and
                self.current_model != model and
                self.current_model.is_model_loaded()):
                if progress_callback:
                    progress_callback(f"🔄 Unloading {self.current_model_name}...")
                self.current_model.unload_model()

            # Load the requested model
            try:
                success = model.load_model(quantization_type, progress_callback)
                if success:
                    self.current_model = model
                    self.current_model_name = model_name
                    print(f"✅ Successfully loaded {model_name} with {quantization_type}")
                    return True
                else:
                    if progress_callback:
                        progress_callback(f"❌ Failed to load {model_name}")
                    return False
            except Exception as e:
                # Failures are reported via return value/callback, not raised,
                # so UI callers can degrade gracefully.
                error_msg = f"Error loading {model_name}: {str(e)}"
                print(error_msg)
                if progress_callback:
                    progress_callback(f"❌ {error_msg}")
                return False

    def unload_current_model(self) -> None:
        """Unload the currently loaded model (no-op when none is loaded)."""
        with self.loading_lock:
            if self.current_model and self.current_model.is_model_loaded():
                print(f"🔄 Unloading {self.current_model_name}...")
                self.current_model.unload_model()
                self.current_model = None
                self.current_model_name = None
                print("✅ Model unloaded successfully")
            else:
                print("ℹ️ No model currently loaded")

    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference using the currently loaded model.

        Args:
            image_path: Path to the image file
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters

        Returns:
            Model's text response

        Raises:
            RuntimeError: If no model is currently loaded.
        """
        if not self.current_model or not self.current_model.is_model_loaded():
            raise RuntimeError("No model is currently loaded. Load a model first.")

        return self.current_model.inference(image_path, prompt, **kwargs)

    def get_current_model_status(self) -> str:
        """Get a human-readable status string for the currently loaded model."""
        if not self.current_model or not self.current_model.is_model_loaded():
            return "❌ No model loaded"

        quantization = self.current_model.current_quantization or "Unknown"
        model_class = self.current_model.__class__.__name__
        return f"✅ {self.current_model_name} ({model_class}) loaded with {quantization}"

    def get_supported_quantizations(self, model_name: str) -> list[str]:
        """Get supported quantization methods for a model.

        Raises:
            KeyError: If the model name is not registered.
        """
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")

        return self.models[model_name].get_supported_quantizations()

    def validate_model_and_quantization(self, model_name: str, quantization_type: str) -> bool:
        """
        Validate if a model and quantization combination is valid.

        Args:
            model_name: Name of the model
            quantization_type: Type of quantization

        Returns:
            True if valid, False otherwise (including unknown model names)
        """
        if model_name not in self.models:
            return False

        return self.models[model_name].validate_quantization(quantization_type)

    def get_model_memory_requirements(self, model_name: str) -> Dict[str, int]:
        """Get memory requirements for a specific model.

        Raises:
            KeyError: If the model name is not registered.
        """
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")

        return self.models[model_name].get_memory_requirements()

    def preload_default_model(self) -> bool:
        """
        Preload the default model specified in configuration.

        Returns:
            True if successful, False otherwise (failures are logged,
            never raised, so startup can continue without a model).
        """
        default_model = self.config_manager.get_default_model()
        default_quantization = self.config_manager.get_default_quantization(default_model)

        print(f"🚀 Preloading default model: {default_model} with {default_quantization}")

        try:
            return self.load_model(default_model, default_quantization)
        except Exception as e:
            print(f"⚠️ Failed to preload default model: {str(e)}")
            return False

    def __str__(self) -> str:
        """String representation of the model manager."""
        loaded_info = f"Current: {self.current_model_name}" if self.current_model_name else "None loaded"
        return f"ModelManager({len(self.models)} models available, {loaded_info})"

    def __repr__(self) -> str:
        """Detailed string representation."""
        models_list = list(self.models.keys())
        return f"ModelManager(models={models_list}, current={self.current_model_name})"
backend/models/qwen/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .qwen_model import QwenModel
2
+
3
+ __all__ = ['QwenModel']
backend/models/qwen/qwen_model.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gc
3
+ import os
4
+ from typing import Dict, Any, Optional, Callable
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ from ..base_model import BaseModel
7
+ from ...config.config_manager import ConfigManager
8
+
9
+
10
+ class QwenModel(BaseModel):
11
+ """Qwen2.5 model implementation."""
12
+
13
+ def __init__(self, model_name: str, model_config: Dict[str, Any], config_manager: ConfigManager):
14
+ """
15
+ Initialize the Qwen model.
16
+
17
+ Args:
18
+ model_name: Name of the model
19
+ model_config: Configuration dictionary for the model
20
+ config_manager: Configuration manager instance
21
+ """
22
+ super().__init__(model_name, model_config)
23
+ self.config_manager = config_manager
24
+
25
+ # Set environment variable for CUDA memory allocation
26
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
27
+
28
+ def check_model_exists_locally(self) -> bool:
29
+ """Check if model exists locally in Hugging Face cache."""
30
+ try:
31
+ from transformers.utils import cached_file
32
+ cached_file(self.model_id, "config.json", local_files_only=True)
33
+ return True
34
+ except:
35
+ return False
36
+
37
+ def download_model_with_progress(self, progress_callback: Optional[Callable] = None) -> bool:
38
+ """
39
+ Download model with progress tracking.
40
+
41
+ Args:
42
+ progress_callback: Callback function for progress updates
43
+
44
+ Returns:
45
+ True if successful, False otherwise
46
+ """
47
+ try:
48
+ if progress_callback:
49
+ progress_callback("πŸ“₯ Downloading tokenizer...")
50
+
51
+ # Download tokenizer first (smaller)
52
+ tokenizer = AutoTokenizer.from_pretrained(self.model_id)
53
+
54
+ if progress_callback:
55
+ progress_callback("πŸ“₯ Downloading model weights... This may take several minutes...")
56
+
57
+ # Download model config and weights by trying to load config
58
+ model = AutoModelForCausalLM.from_pretrained(
59
+ self.model_id,
60
+ torch_dtype="auto",
61
+ device_map="cpu", # Just download, don't load to GPU yet
62
+ low_cpu_mem_usage=True
63
+ )
64
+
65
+ # Clean up the test loading
66
+ del model
67
+
68
+ if progress_callback:
69
+ progress_callback("βœ… Model downloaded successfully!")
70
+
71
+ return True
72
+ except Exception as e:
73
+ if progress_callback:
74
+ progress_callback(f"❌ Download failed: {str(e)}")
75
+ return False
76
+
77
    def load_model(self, quantization_type: str, progress_callback: Optional[Callable] = None) -> bool:
        """
        Load the model with specified quantization.

        Downloads the weights first if they are not cached locally, then loads
        the tokenizer and the model. On any failure the partially-loaded state
        is torn down via unload_model() before re-raising.

        Args:
            quantization_type: Type of quantization to use; must be one of the
                values accepted by self.validate_quantization (e.g. contains
                "non-quantized" for fp16/auto, otherwise 8-bit).
            progress_callback: Optional callable receiving human-readable
                progress strings for UI display.

        Returns:
            True if successful (already-loaded with the same quantization also
            returns True without reloading).

        Raises:
            ValueError: If quantization_type is not supported for this model.
            Exception: If download or loading fails (original error message is
                wrapped; the model is unloaded first).
        """
        if not self.validate_quantization(quantization_type):
            raise ValueError(f"Quantization type '{quantization_type}' not supported for {self.model_name}")

        # If model is already loaded with the same quantization, return early —
        # avoids an expensive reload when the dropdown selection didn't change.
        if (self.model is not None and self.tokenizer is not None and
            self.current_quantization == quantization_type):
            if progress_callback:
                progress_callback(f"βœ… {self.model_name} already loaded!")
            return True

        print(f"Loading {self.model_name} with {quantization_type} quantization...")
        if progress_callback:
            progress_callback(f"πŸ”„ Loading {self.model_name} with {quantization_type} quantization...")

        try:
            # Check if model exists locally; download (with progress) if not.
            model_exists = self.check_model_exists_locally()
            if not model_exists:
                if progress_callback:
                    progress_callback(f"πŸ“₯ {self.model_name} not found locally. Starting download...")
                print(f"Model {self.model_name} not found locally. Starting download...")
                success = self.download_model_with_progress(progress_callback)
                if not success:
                    raise Exception(f"Failed to download {self.model_name}")
            else:
                if progress_callback:
                    progress_callback(f"βœ… {self.model_name} found locally.")

            # Clear any previously loaded model before loading the new one so
            # both never occupy GPU memory at the same time.
            if self.model is not None:
                self.unload_model()

            # Print memory before loading (debugging aid)
            self._print_gpu_memory("before loading")

            if progress_callback:
                progress_callback(f"πŸš€ Loading {self.model_name} tokenizer...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

            # Load model based on quantization type
            if progress_callback:
                progress_callback(f"πŸš€ Loading {self.model_name} model...")

            if "non-quantized" in quantization_type:
                # Load with auto dtype and automatic device mapping
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    torch_dtype="auto",
                    device_map="auto",
                    low_cpu_mem_usage=True
                )
            else:  # quantized (8bit)
                print("Loading with 8-bit quantization to reduce memory usage...")
                # NOTE(review): load_in_8bit= is deprecated in recent
                # transformers releases in favor of BitsAndBytesConfig via
                # quantization_config= — confirm the pinned transformers
                # version before migrating.
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    torch_dtype="auto",
                    load_in_8bit=True,
                    device_map="auto",
                    low_cpu_mem_usage=True
                )

            # Verify model and tokenizer are properly loaded
            if self.model is None:
                raise Exception(f"Model failed to load for {self.model_name}")
            if self.tokenizer is None:
                raise Exception(f"Tokenizer failed to load for {self.model_name}")

            self.current_quantization = quantization_type
            self.is_loaded = True

            success_msg = f"βœ… {self.model_name} loaded successfully with {quantization_type} quantization!"
            print(success_msg)
            if progress_callback:
                progress_callback(success_msg)

            # Print GPU memory usage after loading (debugging aid)
            self._print_gpu_memory("after loading")

            return True

        except Exception as e:
            error_msg = f"Failed to load model {self.model_name}: {str(e)}"
            print(error_msg)
            if progress_callback:
                progress_callback(f"❌ {error_msg}")

            # Reset on failure so no half-initialized model lingers in memory.
            self.unload_model()
            raise Exception(error_msg)
179
+
180
+ def unload_model(self) -> None:
181
+ """Unload the model from memory."""
182
+ if self.model is not None:
183
+ print("🧹 Clearing model from memory...")
184
+ del self.model
185
+ self.model = None
186
+
187
+ if self.tokenizer is not None:
188
+ del self.tokenizer
189
+ self.tokenizer = None
190
+
191
+ self.current_quantization = None
192
+ self.is_loaded = False
193
+
194
+ # Clear GPU cache
195
+ if torch.cuda.is_available():
196
+ torch.cuda.empty_cache()
197
+
198
+ # Force garbage collection
199
+ gc.collect()
200
+
201
+ if torch.cuda.is_available():
202
+ torch.cuda.empty_cache() # Clear again after gc
203
+
204
+ print("βœ… Model unloaded successfully")
205
+
206
    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference with a text prompt.
        Note: Qwen2.5 is a text-only model, so image_path is ignored.

        Args:
            image_path: Path to the image file (ignored for text-only models)
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters. Recognized keys and
                defaults: max_new_tokens (512), do_sample (True),
                temperature (0.7), top_p (0.9).

        Returns:
            Model's text response. On failure this returns the error message
            string instead of raising, so callers must not assume the return
            value is always a model answer.

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self.is_loaded:
            raise RuntimeError(f"Model {self.model_name} is not loaded. Call load_model() first.")

        if not prompt or not prompt.strip():
            return "Error: No prompt provided"

        try:
            # Prepare messages for the Qwen chat format (fixed system prompt)
            messages = [
                {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ]

            # Apply chat template to get the raw prompt string expected by the model
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            # Tokenize input and move tensors to the model's device
            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

            # Generate response (sampling parameters overridable via kwargs)
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=kwargs.get('max_new_tokens', 512),
                do_sample=kwargs.get('do_sample', True),
                temperature=kwargs.get('temperature', 0.7),
                top_p=kwargs.get('top_p', 0.9),
                pad_token_id=self.tokenizer.eos_token_id
            )

            # Extract only the generated part (strip the echoed input tokens)
            generated_ids = [
                output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]

            # Decode response (single-element batch)
            response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return response

        except Exception as e:
            # Deliberate best-effort behavior: report the failure as text so a
            # batch run over many prompts does not abort on one bad item.
            error_msg = f"Error processing prompt: {str(e)}"
            print(error_msg)
            return error_msg
265
+
266
+ def _print_gpu_memory(self, stage: str) -> None:
267
+ """Print GPU memory usage for debugging."""
268
+ if torch.cuda.is_available():
269
+ print(f"Memory {stage}:")
270
+ for i in range(torch.cuda.device_count()):
271
+ allocated = torch.cuda.memory_allocated(i) / 1024**3
272
+ reserved = torch.cuda.memory_reserved(i) / 1024**3
273
+ print(f"GPU {i}: Allocated {allocated:.2f} GB, Reserved {reserved:.2f} GB")
backend/utils/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .image_processing import (
2
+ build_transform,
3
+ find_closest_aspect_ratio,
4
+ dynamic_preprocess,
5
+ load_image
6
+ )
7
+ from .data_processing import (
8
+ extract_file_dict,
9
+ validate_data,
10
+ extract_binary_output
11
+ )
12
+ from .metrics import (
13
+ create_confusion_matrix_plot,
14
+ create_accuracy_table,
15
+ save_dataframe_to_csv
16
+ )
17
+
18
+ __all__ = [
19
+ 'build_transform',
20
+ 'find_closest_aspect_ratio',
21
+ 'dynamic_preprocess',
22
+ 'load_image',
23
+ 'extract_file_dict',
24
+ 'validate_data',
25
+ 'extract_binary_output',
26
+ 'create_confusion_matrix_plot',
27
+ 'create_accuracy_table',
28
+ 'save_dataframe_to_csv'
29
+ ]
backend/utils/data_processing.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import pandas as pd
5
+
6
def extract_file_dict(folder_path: List[Path]) -> Dict[str, Path]:
    """
    Extract file dictionary from folder path.

    Fixes the original implementation, which called
    ``filepath.name.split("/")[-1]`` — redundant because ``Path.name`` never
    contains a separator, and broken when Gradio hands back plain strings
    (``str`` has no ``.name``). Entries are normalized to ``Path`` so both
    input types work.

    Args:
        folder_path: List of Path objects (or path strings) from Gradio file
            upload.

    Returns:
        Dictionary mapping base filename to the full path.
    """
    file_dict: Dict[str, Path] = {}
    for file in folder_path:
        filepath = Path(file)  # accepts both Path and str
        file_dict[filepath.name] = filepath
    return file_dict
22
+
23
+
24
def validate_data(file_dict: Dict[str, Path]) -> Tuple[Union[bool, str], str]:
    """
    Validate the uploaded data structure.

    Args:
        file_dict: Dictionary of filename to path mappings

    Returns:
        Tuple of (validation_result, message)
        validation_result can be:
        - True: Valid data with CSV
        - False: Invalid data
        - "no_csv": Valid but no CSV file
        - "multiple_csv": Valid but multiple CSV files
    """
    csv_names = [name for name in file_dict if name.lower().endswith('.csv')]

    picture_suffixes = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')
    picture_names = [name for name in file_dict if name.lower().endswith(picture_suffixes)]

    # Images are mandatory; everything else is negotiable.
    if not picture_names:
        return False, "No image files found in the folder or subfolders"

    # Zero or multiple CSVs: fall back to file-based processing downstream.
    if not csv_names:
        return "no_csv", "No CSV file found. Will extract data from file paths and names."
    if len(csv_names) > 1:
        return "multiple_csv", "Multiple CSV files found. Will extract data from file paths and names."

    # Exactly one CSV: it must be readable and carry the expected columns.
    try:
        table = pd.read_csv(file_dict[csv_names[0]])
    except Exception as e:
        return False, f"Error reading CSV file: {str(e)}"

    for required in ('Ground Truth', 'Image Name'):
        if required not in table.columns:
            return False, f"CSV file does not contain '{required}' column"

    return True, "Data validation successful"
66
+
67
+
68
def extract_binary_output(
    model_output: str,
    ground_truth: str = "",
    all_ground_truths: List[str] = None
) -> str:
    """
    Extract binary output from model response based on unique ground truth keywords.

    Fixes two defects in the original: plain substring matching let short
    keywords fire inside longer words (e.g. "no" inside "not" or "nothing"),
    and the debug line mislabelled the first output line as "unique keywords".
    Matching now uses word boundaries, with re.escape so keywords containing
    regex metacharacters are safe.

    Args:
        model_output: The model's text response
        ground_truth: Current item's ground truth (unused; kept for
            backward-compatible signature)
        all_ground_truths: List of all ground truth values to extract unique
            keywords from

    Returns:
        First matching keyword (lowercased), or the manual-entry sentinel
        string when no keyword appears on the first output line.
    """
    if all_ground_truths is None:
        all_ground_truths = []

    # Unique lowercase keywords, sorted for a deterministic match order
    unique_keywords = sorted({str(gt).strip().lower() for gt in all_ground_truths if gt})

    # Only the first line of the response is considered authoritative
    first_line = model_output.split("\n", 1)[0].lower()

    print(f"DEBUG: Unique keywords: {unique_keywords}; first line: {first_line}")
    print(f"DEBUG: Model output: {model_output[:100]}...")  # First 100 chars

    for keyword in unique_keywords:
        # \b boundaries prevent e.g. 'no' matching inside 'not'/'nothing'
        if re.search(r"\b" + re.escape(keyword) + r"\b", first_line):
            return keyword

    return "Enter the output manually"
backend/utils/image_processing.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torchvision.transforms as T
5
+ from torchvision.transforms.functional import InterpolationMode
6
+ from typing import List, Tuple, Union
7
+
8
+ # Constants from InternVL preprocessing
9
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
10
+ IMAGENET_STD = (0.229, 0.224, 0.225)
11
+
12
+
13
def build_transform(input_size: int = 448) -> T.Compose:
    """
    Return torchvision transform matching InternVL pre-training.

    Args:
        input_size: Input image size (default: 448)

    Returns:
        Composed torchvision transforms: RGB conversion, bicubic resize to a
        square, tensor conversion, and ImageNet normalization.
    """
    steps = [
        # Force RGB so grayscale/RGBA inputs do not break the normalization.
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
29
+
30
+
31
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: List[Tuple[int, int]],
    width: int,
    height: int,
    image_size: int
) -> Tuple[int, int]:
    """
    Find the closest aspect ratio from target ratios.

    Args:
        aspect_ratio: Current image aspect ratio
        target_ratios: List of target aspect ratios as (width, height) tuples
        width: Original image width
        height: Original image height
        image_size: Target image size

    Returns:
        Best matching aspect ratio as (width, height) tuple
    """
    best = (1, 1)
    best_gap = float("inf")
    source_area = width * height

    for cand_w, cand_h in target_ratios:
        gap = abs(aspect_ratio - cand_w / cand_h)
        improved = gap < best_gap
        # On a tie, prefer the candidate only when the source image is big
        # enough to meaningfully fill that many tiles.
        tie_break = (gap == best_gap and
                     source_area > 0.5 * image_size * image_size * cand_w * cand_h)
        if improved or tie_break:
            best_gap = gap
            best = (cand_w, cand_h)

    return best
65
+
66
+
67
def dynamic_preprocess(
    image: Image.Image,
    min_num: int = 1,
    max_num: int = 12,
    image_size: int = 448,
    use_thumbnail: bool = False
) -> List[Image.Image]:
    """
    Split arbitrarily-sized image into at most max_num tiles of
    image_size x image_size (InternVL spec).

    Args:
        image: Input PIL Image
        min_num: Minimum number of tiles
        max_num: Maximum number of tiles
        image_size: Size of each tile (default: 448)
        use_thumbnail: Whether to append a thumbnail of the whole image

    Returns:
        List of processed image tiles
    """
    src_w, src_h = image.size
    src_aspect = src_w / src_h

    # Every (cols, rows) grid whose tile count lies in [min_num, max_num],
    # ordered by total tile count.
    candidate_grids = sorted(
        {(i, j) for n in range(min_num, max_num + 1)
         for i in range(1, n + 1)
         for j in range(1, n + 1)
         if min_num <= i * j <= max_num},
        key=lambda grid: grid[0] * grid[1],
    )

    cols, rows = find_closest_aspect_ratio(src_aspect, candidate_grids, src_w, src_h, image_size)
    grid_w, grid_h = image_size * cols, image_size * rows
    tile_count = cols * rows

    canvas = image.resize((grid_w, grid_h))

    # Cut the resized canvas row-major into image_size x image_size tiles.
    per_row = grid_w // image_size  # == cols
    tiles = []
    for k in range(tile_count):
        left = (k % per_row) * image_size
        top = (k // per_row) * image_size
        tiles.append(canvas.crop((left, top, left + image_size, top + image_size)))

    # Append a full-image thumbnail only when tiling actually split the image.
    if use_thumbnail and tile_count != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles
121
+
122
+
123
def load_image(
    path: str,
    input_size: int = 448,
    max_num: int = 12
) -> torch.Tensor:
    """
    Load and preprocess image for InternVL model.

    Args:
        path: Path to the image file
        input_size: Input image size (default: 448)
        max_num: Maximum number of tiles (default: 12)

    Returns:
        Tensor of shape (N, 3, H, W) ready for InternVL, where N is the
        number of tiles (plus one thumbnail when the image was split).
    """
    picture = Image.open(path).convert("RGB")
    to_tensor = build_transform(input_size)
    patches = dynamic_preprocess(
        picture,
        image_size=input_size,
        use_thumbnail=True,
        max_num=max_num,
    )
    return torch.stack([to_tensor(patch) for patch in patches])
backend/utils/metrics.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ import tempfile
6
+ from typing import Tuple, Optional
7
+ from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
8
+
9
+
10
def create_confusion_matrix_plot(
    cm: np.ndarray,
    accuracy: float,
    labels: tuple = ('No', 'Yes')
) -> str:
    """
    Create a confusion matrix heatmap and save it to a temporary PNG file.

    Fixes the original's use of tempfile.mktemp (deprecated and race-prone:
    the name could be claimed by another process between generation and
    savefig) and its mutable default argument for labels.

    Args:
        cm: Confusion matrix array
        accuracy: Accuracy score, shown in the plot title
        labels: Axis tick labels (any sequence of two strings)

    Returns:
        Path to the saved plot file; the caller owns (and must clean up)
        the file.
    """
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.title(f'Confusion Matrix (Accuracy: {accuracy:.1%})')
    plt.ylabel('Ground Truth')
    plt.xlabel('Model Prediction')

    # NamedTemporaryFile reserves the filename atomically, unlike mktemp.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        plot_path = tmp.name
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close()

    return plot_path
37
+
38
+
39
def create_accuracy_table(df: pd.DataFrame) -> Tuple[pd.DataFrame, str, pd.DataFrame]:
    """
    Create accuracy metrics table and confusion matrix from results dataframe.

    The two label columns are lowercased and mapped to {0, 1} by alphabetical
    order of the distinct values (first value -> 0, all others -> 1), then
    accuracy/precision/recall/F1 and a confusion matrix plot are computed.

    Args:
        df: DataFrame with 'Ground Truth' and 'Binary Output' columns

    Returns:
        Tuple of (metrics_df, confusion_matrix_plot_path, confusion_matrix_values_df)

    Raises:
        ValueError: If insufficient data for binary classification, or if no
            rows survive the binary mapping.
    """
    # Work on a copy so the caller's dataframe is never mutated.
    df_copy = df.copy()

    # Get unique values from both Ground Truth and Binary Output.
    # Convert to string first, then apply .str operations (columns may hold
    # non-string values such as NaN or booleans).
    ground_truth_values = df_copy['Ground Truth'].dropna().astype(str).str.lower().unique()
    binary_output_values = df_copy['Binary Output'].dropna().astype(str).str.lower().unique()

    # Combine and get all unique values, dropping empty strings.
    all_values = set(list(ground_truth_values) + list(binary_output_values))
    all_values = [v for v in all_values if v.strip()]

    if len(all_values) < 2:
        raise ValueError("Need at least 2 different values for binary classification")

    # Sort values to ensure consistent mapping (alphabetical order)
    sorted_values = sorted(all_values)

    # Create mapping: first value (alphabetically) = 0, second = 1.
    # This ensures consistent mapping regardless of order in the data.
    value_mapping = {sorted_values[0]: 0}
    if len(sorted_values) >= 2:
        value_mapping[sorted_values[1]] = 1

    # If there are more than 2 values, map the rest to 1 (positive class)
    for i in range(2, len(sorted_values)):
        value_mapping[sorted_values[i]] = 1

    print(f"Detected binary mapping: {value_mapping}")

    # Apply mapping - convert to string first, then apply .str operations.
    # Values outside the mapping become NaN and are dropped below.
    df_copy['Ground Truth Binary'] = df_copy['Ground Truth'].astype(str).str.lower().map(value_mapping)
    df_copy['Binary Output Binary'] = df_copy['Binary Output'].astype(str).str.lower().map(value_mapping)

    # Remove rows where either ground truth or binary output is NaN
    df_copy = df_copy.dropna(subset=['Ground Truth Binary', 'Binary Output Binary'])

    if len(df_copy) == 0:
        raise ValueError("No valid data for accuracy calculation after mapping. Check that Ground Truth and Binary Output contain valid binary values.")

    # Calculate metrics (zero_division=0 avoids warnings/errors when a class
    # is never predicted).
    cm = confusion_matrix(df_copy['Ground Truth Binary'], df_copy['Binary Output Binary'])
    accuracy = accuracy_score(df_copy['Ground Truth Binary'], df_copy['Binary Output Binary'])
    precision = precision_score(df_copy['Ground Truth Binary'], df_copy['Binary Output Binary'], zero_division=0)
    recall = recall_score(df_copy['Ground Truth Binary'], df_copy['Binary Output Binary'], zero_division=0)
    f1 = f1_score(df_copy['Ground Truth Binary'], df_copy['Binary Output Binary'], zero_division=0)

    # Create metrics dataframe (values pre-formatted as strings for display)
    metrics_data = [
        ["Accuracy", f"{accuracy:.3f}"],
        ["Precision", f"{precision:.3f}"],
        ["Recall", f"{recall:.3f}"],
        ["F1 Score", f"{f1:.3f}"],
        ["Total Samples", f"{len(df_copy)}"]
    ]
    metrics_df = pd.DataFrame(metrics_data, columns=["Metric", "Value"])

    # Create labels for the confusion matrix based on detected values:
    # recover the original-case spelling of each lowercased label by scanning
    # the surviving rows (Ground Truth first, then Binary Output).
    original_labels = []
    for mapped_val in sorted([k for k, v in value_mapping.items() if v in [0, 1]]):
        # Find original case version from the data
        original_case = None
        for val in df_copy['Ground Truth'].dropna():
            if str(val).lower() == mapped_val:
                original_case = str(val)
                break
        if original_case is None:
            for val in df_copy['Binary Output'].dropna():
                if str(val).lower() == mapped_val:
                    original_case = str(val)
                    break
        original_labels.append(original_case if original_case else mapped_val.title())

    # Ensure we have exactly 2 labels (defensive fallback)
    if len(original_labels) < 2:
        original_labels = ['Class 0', 'Class 1']

    cm_plot_path = create_confusion_matrix_plot(cm, accuracy, original_labels)

    # Confusion matrix values table; the 2x2 case gets labelled axes.
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        cm_values = pd.DataFrame(
            [[tn, fp], [fn, tp]],
            columns=[f"Predicted {original_labels[0]}", f"Predicted {original_labels[1]}"],
            index=[f"Actual {original_labels[0]}", f"Actual {original_labels[1]}"]
        )
    else:
        cm_values = pd.DataFrame(cm)

    return metrics_df, cm_plot_path, cm_values
143
+
144
+
145
def save_dataframe_to_csv(df: pd.DataFrame) -> Optional[str]:
    """
    Save dataframe to a temporary CSV file.

    Fixes the original's use of tempfile.mktemp, which is deprecated and
    race-prone (the generated name can be claimed by another process before
    the file is created). NamedTemporaryFile reserves the name atomically.

    Args:
        df: DataFrame to save

    Returns:
        Path to saved CSV file, or None when df is None or empty. The caller
        owns (and must clean up) the file.
    """
    if df is None or df.empty:
        return None

    with tempfile.NamedTemporaryFile(suffix='.csv', delete=False) as tmp:
        csv_path = tmp.name
    df.to_csv(csv_path, index=False)
    return csv_path
config/models.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Configuration for Vision Language Models and Language Models
2
+ # This file contains model configurations for easy integration
3
+
4
+ models:
5
+ # InternVL Vision-Language Models
6
+ InternVL3-8B:
7
+ name: "InternVL3-8B"
8
+ model_id: "OpenGVLab/InternVL3-8B"
9
+ model_type: "internvl"
10
+ description: "Fastest model, good for quick processing"
11
+ supported_quantizations:
12
+ - "non-quantized(fp16)"
13
+ - "quantized(8bit)"
14
+ default_quantization: "non-quantized(fp16)"
15
+
16
+ InternVL3-14B:
17
+ name: "InternVL3-14B"
18
+ model_id: "OpenGVLab/InternVL3-14B"
19
+ model_type: "internvl"
20
+ description: "Balanced performance and quality"
21
+ supported_quantizations:
22
+ - "non-quantized(fp16)"
23
+ - "quantized(8bit)"
24
+ default_quantization: "quantized(8bit)"
25
+
26
+ InternVL3-38B:
27
+ name: "InternVL3-38B"
28
+ model_id: "OpenGVLab/InternVL3-38B"
29
+ model_type: "internvl"
30
+ description: "Highest quality, requires significant GPU memory"
31
+ supported_quantizations:
32
+ - "non-quantized(fp16)"
33
+ - "quantized(8bit)"
34
+ default_quantization: "quantized(8bit)"
35
+
36
+ InternVL3_5-8B:
37
+ name: "InternVL3_5-8B"
38
+ model_id: "OpenGVLab/InternVL3_5-8B"
39
+ model_type: "internvl"
40
+ description: "Fastest model, good for quick processing"
41
+ supported_quantizations:
42
+ - "non-quantized(fp16)"
43
+ - "quantized(8bit)"
44
+ default_quantization: "non-quantized(fp16)"
45
+
46
+ # Qwen Language Models (Text-only)
47
+ Qwen2.5-7B-Instruct:
48
+ name: "Qwen2.5-7B-Instruct"
49
+ model_id: "Qwen/Qwen2.5-7B-Instruct"
50
+ model_type: "qwen"
51
+ description: "Qwen2.5 7B instruction-tuned model for text generation"
52
+ supported_quantizations:
53
+ - "non-quantized(fp16)"
54
+ - "quantized(8bit)"
55
+ default_quantization: "quantized(8bit)"
56
+
57
+ Qwen2.5-14B-Instruct:
58
+ name: "Qwen2.5-14B-Instruct"
59
+ model_id: "Qwen/Qwen2.5-14B-Instruct"
60
+ model_type: "qwen"
61
+ description: "Qwen2.5 14B instruction-tuned model for better text generation"
62
+ supported_quantizations:
63
+ - "non-quantized(fp16)"
64
+ - "quantized(8bit)"
65
+ default_quantization: "quantized(8bit)"
66
+
67
+ # Default model selection
68
+ default_model: "InternVL3-8B"
debug_files.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ File Upload Diagnostic Script
4
+ This script helps debug why some images are not being processed.
5
+ """
6
+
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Dict, List
10
+
11
def analyze_uploaded_files(folder_path: str) -> None:
    """
    Analyze uploaded files to understand why some images might not be processed.

    Prints a diagnostic report: every file found (recursively), which files
    look like images, which CSV files exist, and — when a CSV with an
    'Image Name' column is present — which listed images actually exist on
    disk.

    Fixes in this revision: the summary no longer relies on the fragile
    `'df' in locals()` check (which could also hit a NameError on
    `existing_images` when the CSV read failed mid-way); image detection uses
    exact suffix membership instead of `suffix.endswith(ext)`, which wrongly
    matched suffixes like '.xjpg'; CSV-to-file matching uses a set for O(1)
    lookups.

    Args:
        folder_path: Path to the uploaded folder
    """
    print("πŸ” File Upload Diagnostic Tool")
    print("=" * 50)

    if not os.path.exists(folder_path):
        print(f"❌ Folder not found: {folder_path}")
        return

    # Collect every file, recursively
    all_files = []
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            all_files.append(Path(os.path.join(root, file)))

    print(f"πŸ“ Total files found: {len(all_files)}")
    print("\nπŸ“‹ All files:")
    for i, file_path in enumerate(all_files, 1):
        print(f"  {i}. {file_path.name} (ext: {file_path.suffix.lower()})")

    # Classify by extension
    image_exts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']
    print(f"\nπŸ–ΌοΈ Looking for image extensions: {image_exts}")

    image_files = [p for p in all_files if p.suffix.lower() in image_exts]
    non_image_files = [p for p in all_files if p.suffix.lower() not in image_exts]

    print(f"\nβœ… Image files detected ({len(image_files)}):")
    for i, img in enumerate(image_files, 1):
        print(f"  {i}. {img.name}")

    print(f"\nπŸ“„ Non-image files ({len(non_image_files)}):")
    for i, file in enumerate(non_image_files, 1):
        print(f"  {i}. {file.name} (ext: {file.suffix.lower()})")

    csv_files = [f for f in all_files if f.suffix.lower() == '.csv']
    print(f"\nπŸ“Š CSV files found ({len(csv_files)}):")
    for i, csv in enumerate(csv_files, 1):
        print(f"  {i}. {csv.name}")

    # Initialize up front so the summary section can never hit a NameError,
    # even when the CSV read below fails part-way through.
    df = None
    existing_images = []
    missing_images = []

    # If a CSV exists, inspect its content
    if csv_files:
        try:
            import pandas as pd
            df = pd.read_csv(csv_files[0])
            print(f"\nπŸ“ˆ CSV Analysis for '{csv_files[0].name}':")
            print(f"  - Rows: {len(df)}")
            print(f"  - Columns: {list(df.columns)}")

            if 'Image Name' in df.columns:
                image_names_in_csv = df['Image Name'].tolist()
                print(f"  - Image names in CSV: {len(image_names_in_csv)}")

                # Check which images from the CSV actually exist as files
                # (set membership instead of a nested scan per name).
                on_disk = {img.name for img in image_files}
                for img_name in image_names_in_csv:
                    if img_name in on_disk:
                        existing_images.append(img_name)
                    else:
                        missing_images.append(img_name)

                print(f"\nπŸ”— CSV-to-File Matching:")
                print(f"  - Images in CSV that exist as files: {len(existing_images)}")
                print(f"  - Images in CSV that are missing: {len(missing_images)}")

                if existing_images:
                    print("  βœ… Matching files:")
                    for img in existing_images:
                        print(f"    - {img}")

                if missing_images:
                    print("  ❌ Missing files:")
                    for img in missing_images:
                        print(f"    - {img}")

        except Exception as e:
            print(f"  ❌ Error reading CSV: {e}")

    # Summary
    print(f"\nπŸ“Š SUMMARY:")
    print(f"  - Total files uploaded: {len(all_files)}")
    print(f"  - Image files detected: {len(image_files)}")
    print(f"  - CSV files: {len(csv_files)}")

    if df is not None:
        if 'Image Name' in df.columns:
            print(f"  - Images that will be processed: {len(existing_images)}")
        else:
            print(f"  - CSV exists but no 'Image Name' column - will process all {len(image_files)} images")
    else:
        print(f"  - No CSV - will process all {len(image_files)} images")
117
+
118
if __name__ == "__main__":
    # Interactive entry point: ask for a folder and run the diagnostic on it.
    print("Please provide the path to your uploaded folder:")
    target_folder = input("Folder path: ").strip()
    analyze_uploaded_files(target_folder)
frontend/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .gradio_app import GradioApp
2
+
3
+ __all__ = ['GradioApp']
frontend/gradio_app.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import os
4
+ from pathlib import Path
5
+ import shutil
6
+ import tempfile
7
+ import uuid
8
+ import spaces
9
+ from typing import Optional
10
+
11
+ from backend import ConfigManager, ModelManager, InferenceEngine
12
+ from backend.utils.metrics import create_accuracy_table, save_dataframe_to_csv
13
+
14
+
15
+ class GradioApp:
16
+ """Gradio application for InternVL3 prompt engineering."""
17
+
18
    def __init__(self):
        """Initialize the Gradio application.

        Wires up the backend (config, model manager, inference engine) and
        attempts a best-effort preload of the default model.
        """
        # Initialize backend components; the inference engine depends on both.
        self.config_manager = ConfigManager()
        self.model_manager = ModelManager(self.config_manager)
        self.inference_engine = InferenceEngine(self.model_manager, self.config_manager)

        # Try to preload default model. Failure is deliberately non-fatal:
        # the app still starts and the model loads lazily on first use.
        try:
            self.model_manager.preload_default_model()
            print("βœ… Default model preloaded successfully!")
        except Exception as e:
            print(f"⚠️ Default model preloading failed: {str(e)}")
            print("The model will be loaded when first needed.")
32
+
33
+ def get_current_model_status(self) -> str:
34
+ """Get current model status for display."""
35
+ return self.model_manager.get_current_model_status()
36
+
37
+ def handle_stop_button(self):
38
+ """Handle stop button click."""
39
+ message = self.inference_engine.set_stop_flag()
40
+ return message, gr.update(visible=True)
41
+
42
+ def on_model_change(self, model_selection: str, quantization_type: str) -> str:
43
+ """Handle model/quantization dropdown changes."""
44
+ current_status = self.get_current_model_status()
45
+ if model_selection and quantization_type:
46
+ available_models = self.config_manager.get_available_models()
47
+ target_id = available_models.get(model_selection)
48
+ current_model_id = None
49
+ if self.model_manager.current_model:
50
+ current_model_id = self.model_manager.current_model.model_id
51
+
52
+ if (current_model_id != target_id or
53
+ (self.model_manager.current_model and
54
+ self.model_manager.current_model.current_quantization != quantization_type)):
55
+ return f"πŸ”„ Will load {model_selection} with {quantization_type} when processing starts"
56
+ return current_status
57
+
58
+ def get_model_choices_with_info(self) -> list[str]:
59
+ """Get model choices with type information for dropdown."""
60
+ choices = []
61
+ for model_name in self.config_manager.get_available_models().keys():
62
+ model_config = self.config_manager.get_model_config(model_name)
63
+ model_type = model_config.get('model_type', 'unknown').upper()
64
+ choices.append(f"{model_name} ({model_type})")
65
+ return choices
66
+
67
+ def extract_model_name_from_choice(self, choice: str) -> str:
68
+ """Extract the actual model name from the dropdown choice."""
69
+ return choice.split(' (')[0] if ' (' in choice else choice
70
+
71
    def update_image_preview(self, evt: gr.SelectData, df, folder_path):
        """Update image preview when a results-table row is selected.

        Args:
            evt: Gradio selection event; evt.index[0] is the selected row.
            df: The displayed (possibly trimmed) results dataframe.
            folder_path: Unused here; kept for the Gradio event signature.

        Returns:
            Tuple of (preview image path or None, model output string).
        """
        if df is None or evt.index[0] >= len(df):
            return None, ""
        try:
            # Use the full dataframe (with image paths) kept on the inference
            # engine; the displayed table does not carry the path column.
            full_df = getattr(self.inference_engine, 'full_df', None)
            if full_df is None or evt.index[0] >= len(full_df):
                return None, ""
            selected_row = full_df.iloc[evt.index[0]]
            image_path = selected_row["Image Path"]
            model_output = selected_row["Model Output"]
            if not os.path.exists(image_path):
                # Image missing on disk: still surface the model output text.
                return None, model_output
            # Copy to a uniquely named temp file — presumably so Gradio can
            # serve it from its temp directory; confirm against Gradio's
            # allowed-paths behavior.
            file_extension = Path(image_path).suffix
            temp_filename = f"gradio_preview_{uuid.uuid4().hex}{file_extension}"
            temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
            shutil.copy2(image_path, temp_path)
            return temp_path, model_output
        except Exception as e:
            # Preview is non-critical: log and fall back to an empty preview.
            print(f"Error loading image preview: {e}")
            return None, ""
93
+
94
+ def download_results_csv(self, results_table_data):
95
+ """Download results as CSV file."""
96
+ try:
97
+ print(f"Download function called with data type: {type(results_table_data)}")
98
+
99
+ if results_table_data is None:
100
+ print("No data to download")
101
+ return None
102
+
103
+ # Handle different data types from Gradio
104
+ if hasattr(results_table_data, 'values'):
105
+ # If it's a pandas DataFrame
106
+ df = results_table_data
107
+ elif isinstance(results_table_data, list):
108
+ # If it's a list of lists or list of dicts
109
+ if len(results_table_data) == 0:
110
+ print("Empty data")
111
+ return None
112
+ df = pd.DataFrame(results_table_data, columns=["S.No", "Image Name", "Ground Truth", "Binary Output", "Model Output"])
113
+ else:
114
+ # Try to convert to DataFrame
115
+ df = pd.DataFrame(results_table_data)
116
+
117
+ print(f"DataFrame shape: {df.shape}")
118
+ print(f"DataFrame columns: {df.columns.tolist()}")
119
+
120
+ # Create temporary file
121
+ temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False)
122
+ df.to_csv(temp_file.name, index=False)
123
+ temp_file.close()
124
+
125
+ print(f"CSV file created: {temp_file.name}")
126
+ return temp_file.name
127
+
128
+ except Exception as e:
129
+ print(f"Error in download_results_csv: {str(e)}")
130
+ import traceback
131
+ traceback.print_exc()
132
+ return None
133
+
134
+ def submit_and_show_metrics(self, df):
135
+ """Generate and show metrics for results."""
136
+ if df is None:
137
+ return df, df, None, None, None, gr.update(visible=False), gr.update(visible=False), ""
138
+
139
+ # Only create metrics if all outputs are valid yes/no responses
140
+ try:
141
+ metrics_df, cm_plot_path, cm_values = create_accuracy_table(df)
142
+ return df, df, metrics_df, cm_plot_path, cm_values, gr.update(visible=True), gr.update(visible=True), "πŸ“Š Metrics calculated successfully!"
143
+ except Exception as e:
144
+ print(f"Could not create metrics: {str(e)}")
145
+ return df, df, None, None, None, gr.update(visible=False), gr.update(visible=True), f"⚠️ Could not calculate metrics: {str(e)}"
146
+
147
+ @spaces.GPU
148
+ def process_input_ui(self, folder_path, prompt, quantization_type, model_selection):
149
+ """UI wrapper for processing input with progress updates."""
150
+ if not folder_path or not prompt.strip():
151
+ return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False),
152
+ "Please upload a folder and enter a prompt.", None, None, None,
153
+ gr.update(visible=False), gr.update(visible=False),
154
+ gr.update(value="⚠️ Please upload a folder and enter a prompt.", visible=True), "", gr.update(visible=False))
155
+
156
+ # Extract actual model name from the dropdown choice
157
+ actual_model_name = self.extract_model_name_from_choice(model_selection)
158
+
159
+ # Check if model needs to be downloaded and show progress
160
+ available_models = self.config_manager.get_available_models()
161
+ model_id = available_models[actual_model_name]
162
+
163
+ # Show processing message and hide stop status
164
+ yield (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
165
+ None, None, None, None,
166
+ gr.update(visible=False), gr.update(visible=False),
167
+ gr.update(value="πŸš€ Initializing processing...", visible=True), prompt, gr.update(visible=False))
168
+
169
+ # Process the input
170
+ error, show_results, show_image, table, error_message, final_message = self.inference_engine.process_folder_input(
171
+ folder_path, prompt, quantization_type, actual_model_name, gr.Progress()
172
+ )
173
+
174
+ # If error is visible, show results section but keep error visible
175
+ if error["visible"]:
176
+ yield (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True),
177
+ error, None, None, None,
178
+ gr.update(visible=False), gr.update(visible=False),
179
+ gr.update(value=final_message, visible=True), prompt, gr.update(visible=False))
180
+ else:
181
+ yield (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True),
182
+ None, show_results, show_image, table,
183
+ gr.update(visible=True), gr.update(visible=False),
184
+ gr.update(value=final_message, visible=True), prompt, gr.update(visible=False))
185
+
186
+ def rerun_ui(self, df, new_prompt, quantization_type, model_selection):
187
+ """UI wrapper for rerun with progress updates."""
188
+ if df is None or not new_prompt.strip():
189
+ return (df, None, None, None,
190
+ gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
191
+ gr.update(visible=False), gr.update(visible=True), "⚠️ Please provide a valid prompt", "")
192
+
193
+ # Extract actual model name from the dropdown choice
194
+ actual_model_name = self.extract_model_name_from_choice(model_selection)
195
+
196
+ # Hide all sections and show only processing, clear model output display
197
+ yield (df, None, None, None,
198
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
199
+ gr.update(visible=False), gr.update(visible=True), "πŸš€ Initializing reprocessing...", "Select a row from the table to see model output...")
200
+
201
+ # Process with new prompt
202
+ updated_df, accuracy_table_data, cm_plot, cm_values, section4_vis, progress_vis, final_message = self.inference_engine.rerun_with_new_prompt(
203
+ df, new_prompt, quantization_type, actual_model_name, gr.Progress()
204
+ )
205
+
206
+ # Show prompt editing and results sections again, show Generate Metrics button, hide progress, and clear model output display
207
+ yield (updated_df, accuracy_table_data, cm_plot, cm_values,
208
+ gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), section4_vis,
209
+ gr.update(visible=True), gr.update(visible=False), final_message, "Select a row from the table to see updated model output...")
210
+
211
    def create_interface(self):
        """Create and return the Gradio interface.

        Builds a ``gr.Blocks`` app with four sections — (1) folder upload and
        prompt entry, (2) prompt editing / rerun controls, (3) results table
        with image preview and CSV download, (4) metrics and confusion
        matrix — plus model/quantization dropdowns, a model-status line and a
        stop button, and wires all event handlers between them.

        NOTE(review): the handler ``outputs=`` lists below are positional and
        must stay in lockstep with the tuples yielded/returned by the bound
        methods — confirm both sides when changing either.
        """
        # CSS from original app.py
        css = """
        .progress {
            margin: 15px 0;
            padding: 20px;
            border-radius: 12px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            border: none;
            color: white;
            font-weight: 600;
            font-size: 16px;
            text-align: center;
            box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
            animation: progressPulse 2s ease-in-out infinite alternate;
        }

        @keyframes progressPulse {
            0% {
                transform: scale(1);
                box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
            }
            100% {
                transform: scale(1.02);
                box-shadow: 0 6px 20px rgba(102, 126, 234, 0.4);
            }
        }

        .processing {
            background: linear-gradient(45deg, #f0f9ff, #e3f2fd);
            border: 2px solid #1976d2;
            border-radius: 10px;
            padding: 20px;
            text-align: center;
            margin: 10px 0;
        }

        .gr-button.processing {
            background-color: #ffa726 !important;
            color: white !important;
            pointer-events: none;
        }

        /* Stop button styling */
        .stop-button {
            background: linear-gradient(135deg, #ff4757 0%, #c44569 100%) !important;
            border: none !important;
            color: white !important;
            font-weight: 700 !important;
            font-size: 16px !important;
            box-shadow: 0 4px 15px rgba(255, 71, 87, 0.4) !important;
            transition: all 0.3s ease !important;
        }

        .stop-button:hover {
            transform: translateY(-2px) !important;
            box-shadow: 0 8px 25px rgba(255, 71, 87, 0.6) !important;
            background: linear-gradient(135deg, #ff3742 0%, #b83754 100%) !important;
        }

        .stop-status {
            color: #ff4757;
            font-weight: 600;
            background: rgba(255, 71, 87, 0.1);
            padding: 10px;
            border-radius: 8px;
            border-left: 4px solid #ff4757;
            margin: 10px 0;
        }

        /* Enhanced button styling */
        .gr-button {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            border: none;
            border-radius: 8px;
            color: white;
            font-weight: 600;
            transition: all 0.3s ease;
        }

        .gr-button:hover {
            transform: translateY(-2px);
            box-shadow: 0 8px 25px rgba(102, 126, 234, 0.4);
        }
        """

        with gr.Blocks(theme="origin", css=css) as demo:
            gr.Markdown("""
            <h1 style='text-align:center; color:#1976d2; font-size:2.5em; font-weight:bold; margin-bottom:40px!important;'>PROMPT_PILOT</h1>
            <p style='text-align:center; color:#666; font-size:1.1em; margin-bottom:30px;'>
            πŸ€– AI-powered analysis with different vision models
            </p>
            <h2 style='text-align:center; color:#666; font-size:1.1em; margin-bottom:30px;'>
            Note: Currently Accuracy only works properly in case of binary output. For other cases kindly download the csv and calculate the accuracy separately.
            </h2>
            """, elem_id="main-title")

            # Model and Quantization selection dropdowns at the top
            model_choices = self.get_model_choices_with_info()
            # NOTE(review): assumes the default model's type is INTERNVL so the
            # label matches one produced by get_model_choices_with_info — confirm.
            default_choice = f"{self.config_manager.get_default_model()} (INTERNVL)"

            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    value=default_choice,
                    label="πŸ€– Model Selection",
                    info="Select model: InternVL (vision+text), Qwen (text-only)",
                    elem_id="model-dropdown"
                )
                quantization_dropdown = gr.Dropdown(
                    choices=["quantized(8bit)", "non-quantized(fp16)"],
                    value="non-quantized(fp16)",
                    label="πŸ”§ Model Quantization",
                    info="Select quantization type: quantized (8bit) uses less memory, non-quantized (fp16) for better quality",
                    elem_id="quantization-dropdown"
                )

            # Model status indicator
            with gr.Row():
                model_status = gr.Markdown(
                    value=self.get_current_model_status(),
                    label="Model Status",
                    elem_classes=["model-status"]
                )

            # Stop button row
            with gr.Row():
                stop_btn = gr.Button("πŸ›‘ STOP PROCESSING", variant="stop", size="lg", elem_classes=["stop-button"])
                stop_status = gr.Markdown("", elem_classes=["stop-status"], visible=False)

            # Section 1: folder upload, prompt entry, and the Proceed button.
            with gr.Row(visible=True) as section1_row:
                with gr.Column():
                    folder_input = gr.File(
                        label="Upload Folder",
                        file_count="directory",
                        type="filepath"
                    )
                with gr.Column():
                    prompt_input = gr.Textbox(
                        label="Enter your prompt here",
                        placeholder="Type your prompt...",
                        lines=3
                    )
                with gr.Column():
                    submit_btn = gr.Button("Proceed", variant="primary")

            # Progress indicator for section 1
            with gr.Row(visible=True) as section1_progress_row:
                section1_progress_message = gr.Markdown("", elem_classes=["progress"], visible=False)

            # Section 2: Edit Prompt and Rerun Controls (separate section)
            with gr.Row(visible=False) as section2_prompt_row:
                with gr.Column():
                    with gr.Row():
                        prompt_input_section2 = gr.Textbox(
                            label="Edit Prompt",
                            placeholder="Modify your prompt here...",
                            lines=2,
                            scale=4
                        )
                        rerun_btn = gr.Button("πŸ”„ Rerun", variant="secondary", size="lg", scale=1)

            # Section 3: Results Display
            with gr.Row(visible=False) as section3_results_row:
                error_message = gr.Textbox(label="Error Message", visible=False)
                with gr.Column(scale=1):
                    image_preview = gr.Image(label="Selected Image", height=270, width=480)
                    model_output_display = gr.Textbox(
                        label="Model Output for Selected Image",
                        placeholder="Select a row from the table to see model output...",
                        interactive=False,
                        lines=3
                    )
                with gr.Column(scale=2):
                    with gr.Row():
                        gr.HTML("")  # Empty space to push button to right
                        download_results_btn = gr.Button("πŸ“₯ CSV", size="sm", scale=1)
                        results_csv_output = gr.File(label="", visible=True, scale=1, show_label=False)
                    results_table = gr.Dataframe(
                        headers=["S.No", "Image Name", "Ground Truth", "Binary Output", "Model Output"],
                        label="Results",
                        interactive=True,  # Make it editable for ground truth input
                        col_count=(5, "fixed")
                    )

            # Generate Metrics button
            with gr.Row(visible=False) as section3_submit_row:
                with gr.Column():
                    submit_results_btn = gr.Button("Generate Metrics", variant="primary", size="lg")

            # Progress indicator row
            with gr.Row(visible=False) as progress_row:
                progress_message = gr.Markdown("", elem_classes=["progress"])

            # Section 4: Metrics and confusion matrix
            with gr.Row(visible=False) as section4_metrics_row:
                with gr.Column(scale=2):
                    confusion_matrix_plot = gr.Image(
                        label="Confusion Matrix"
                    )
                with gr.Column(scale=2):
                    accuracy_table = gr.Dataframe(
                        label="Performance Metrics",
                        interactive=False
                    )
                    confusion_matrix_table = gr.Dataframe(
                        label="Confusion Matrix Table",
                        interactive=False
                    )

            # State to store folder path
            folder_path_state = gr.State()
            folder_input.change(
                fn=lambda x: x,
                inputs=[folder_input],
                outputs=[folder_path_state]
            )

            # Event handlers
            # results_table appears twice in this outputs list on purpose: it
            # matches the 12-tuple yielded by process_input_ui, which carries
            # the table at two positions.
            submit_btn.click(
                fn=self.process_input_ui,
                inputs=[folder_input, prompt_input, quantization_dropdown, model_dropdown],
                outputs=[section1_row, section2_prompt_row, section3_results_row, error_message, results_table, image_preview, results_table, section3_submit_row, section4_metrics_row, section1_progress_message, prompt_input_section2, stop_status]
            )

            # Row selection drives the image preview and per-image output text.
            results_table.select(
                fn=self.update_image_preview,
                inputs=[results_table, folder_path_state],
                outputs=[image_preview, model_output_display]
            )

            submit_results_btn.click(
                fn=self.submit_and_show_metrics,
                inputs=[results_table],
                outputs=[results_table, results_table, accuracy_table, confusion_matrix_plot, confusion_matrix_table, section4_metrics_row, progress_row, progress_message]
            )

            download_results_btn.click(
                fn=self.download_results_csv,
                inputs=[results_table],
                outputs=[results_csv_output]
            )

            rerun_btn.click(
                fn=self.rerun_ui,
                inputs=[results_table, prompt_input_section2, quantization_dropdown, model_dropdown],
                outputs=[results_table, accuracy_table, confusion_matrix_plot, confusion_matrix_table,
                         section1_row, section2_prompt_row, section3_results_row, section4_metrics_row, section3_submit_row, progress_row, progress_message, model_output_display]
            )

            # Model change handler to update status
            model_dropdown.change(
                fn=self.on_model_change,
                inputs=[model_dropdown, quantization_dropdown],
                outputs=[model_status]
            )

            quantization_dropdown.change(
                fn=self.on_model_change,
                inputs=[model_dropdown, quantization_dropdown],
                outputs=[model_status]
            )

            # Stop button click handler
            stop_btn.click(
                fn=self.handle_stop_button,
                inputs=[],
                outputs=[stop_status, stop_status]
            )

        return demo
+ def launch(self, **kwargs):
485
+ """Launch the Gradio application."""
486
+ demo = self.create_interface()
487
+ return demo.launch(**kwargs)
requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Deduplicated dependency list: pip rejects requirements files that name the
# same package twice ("Double requirement given"), so the pinned/source
# variants are kept and the unpinned duplicates removed.
numpy
Requests
decord
git+https://github.com/huggingface/transformers.git
einops
timm
sentencepiece
gradio>=4.19.2
torch>=2.2.0
torchvision>=0.17.0
pillow>=10.2.0
accelerate>=0.27.2
bitsandbytes>=0.42.0
pandas>=1.5.0
matplotlib>=3.5.0
seaborn>=0.11.0
scikit-learn>=1.0.0
pyyaml>=6.0.0
spaces
boto3