Spaces:
Runtime error
Runtime error
Commit
Β·
1314bf5
1
Parent(s):
14ebc7f
First commit
Browse files- README copy.md +13 -0
- app.py +12 -0
- backend/__init__.py +27 -0
- backend/config/__init__.py +3 -0
- backend/config/config_manager.py +120 -0
- backend/inference/__init__.py +3 -0
- backend/inference/inference_engine.py +543 -0
- backend/models/__init__.py +6 -0
- backend/models/base_model.py +109 -0
- backend/models/internvl/__init__.py +3 -0
- backend/models/internvl/internvl_model.py +363 -0
- backend/models/model_manager.py +248 -0
- backend/models/qwen/__init__.py +3 -0
- backend/models/qwen/qwen_model.py +273 -0
- backend/utils/__init__.py +29 -0
- backend/utils/data_processing.py +100 -0
- backend/utils/image_processing.py +147 -0
- backend/utils/metrics.py +160 -0
- config/models.yaml +68 -0
- debug_files.py +121 -0
- frontend/__init__.py +3 -0
- frontend/gradio_app.py +487 -0
- requirements.txt +25 -0
README copy.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Prompt Pilot
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.34.2
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
short_description: test
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""
InternVL3 Prompt Engineering Application
Entry point for the modular InternVL3 image analysis application.
"""

from frontend.gradio_app import GradioApp


def _main() -> None:
    """Build the Gradio application and start serving it."""
    GradioApp().launch()


if __name__ == "__main__":
    # Only start the server when executed directly, never on import.
    _main()
|
backend/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config import ConfigManager
|
| 2 |
+
from .models import ModelManager, InternVLModel, BaseModel
|
| 3 |
+
from .inference import InferenceEngine
|
| 4 |
+
from .utils import (
|
| 5 |
+
build_transform,
|
| 6 |
+
load_image,
|
| 7 |
+
extract_file_dict,
|
| 8 |
+
validate_data,
|
| 9 |
+
extract_binary_output,
|
| 10 |
+
create_accuracy_table,
|
| 11 |
+
save_dataframe_to_csv
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
'ConfigManager',
|
| 16 |
+
'ModelManager',
|
| 17 |
+
'InternVLModel',
|
| 18 |
+
'BaseModel',
|
| 19 |
+
'InferenceEngine',
|
| 20 |
+
'build_transform',
|
| 21 |
+
'load_image',
|
| 22 |
+
'extract_file_dict',
|
| 23 |
+
'validate_data',
|
| 24 |
+
'extract_binary_output',
|
| 25 |
+
'create_accuracy_table',
|
| 26 |
+
'save_dataframe_to_csv'
|
| 27 |
+
]
|
backend/config/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config_manager import ConfigManager
|
| 2 |
+
|
| 3 |
+
__all__ = ['ConfigManager']
|
backend/config/config_manager.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Dict, List, Any, Optional
|
| 5 |
+
|
| 6 |
+
class ConfigManager:
    """Loads and serves the application's YAML configuration.

    The configuration is read once at construction time and cached on the
    instance; call reload_config() to pick up on-disk changes.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the configuration manager and eagerly load the config.

        Args:
            config_path: Path to the configuration file. If None, falls back
                to config/models.yaml under the project root (three levels
                above this module).
        """
        if config_path is None:
            repo_root = Path(__file__).parent.parent.parent
            config_path = repo_root / "config" / "models.yaml"

        self.config_path = Path(config_path)
        self._config = None  # cache; populated by load_config()
        self.load_config()

    def load_config(self) -> None:
        """Load configuration from YAML file.

        Raises:
            FileNotFoundError: If the configured path does not exist.
            ValueError: If the file is not valid YAML.
        """
        try:
            with open(self.config_path, 'r', encoding='utf-8') as fh:
                self._config = yaml.safe_load(fh)
            # NOTE(review): the leading glyph appears mojibake'd in the
            # scraped source — confirm the intended emoji.
            print(f"β Configuration loaded from {self.config_path}")
        except FileNotFoundError:
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")
        except yaml.YAMLError as e:
            raise ValueError(f"Invalid YAML in configuration file: {e}")

    def reload_config(self) -> None:
        """Re-read the configuration from disk."""
        self.load_config()

    @property
    def config(self) -> Dict[str, Any]:
        """The full configuration dictionary, loaded lazily if needed."""
        if self._config is None:
            self.load_config()
        return self._config

    def get_available_models(self) -> Dict[str, str]:
        """Map each configured model name to its model ID."""
        return {
            name: entry['model_id']
            for name, entry in self.config.get('models', {}).items()
        }

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Get configuration for a specific model.

        Args:
            model_name: Name of the model (e.g., 'InternVL3-8B')

        Returns:
            Model configuration dictionary

        Raises:
            KeyError: If model name is not found
        """
        models = self.config.get('models', {})
        if model_name not in models:
            raise KeyError(f"Model '{model_name}' not found. Available models: {list(models.keys())}")
        return models[model_name]

    def get_supported_quantizations(self, model_name: str) -> List[str]:
        """List the quantization methods supported by a model."""
        return self.get_model_config(model_name).get('supported_quantizations', [])

    def get_default_quantization(self, model_name: str) -> str:
        """Return the default quantization method for a model."""
        return self.get_model_config(model_name).get('default_quantization', 'non-quantized(fp16)')

    def get_default_model(self) -> str:
        """Return the configured default model name."""
        return self.config.get('default_model', 'InternVL3-8B')

    def validate_model_and_quantization(self, model_name: str, quantization: str) -> bool:
        """
        Check whether a quantization method is supported for a model.

        Args:
            model_name: Name of the model
            quantization: Quantization method

        Returns:
            True if valid, False otherwise (including unknown model names).
        """
        try:
            return quantization in self.get_supported_quantizations(model_name)
        except KeyError:
            return False

    def apply_environment_settings(self) -> None:
        """Apply process-wide environment settings."""
        # Let the CUDA allocator grow segments instead of fragmenting memory.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    def get_model_description(self, model_name: str) -> str:
        """Return the human-readable description for a model."""
        return self.get_model_config(model_name).get('description', 'No description available')

    def __str__(self) -> str:
        """Short display form."""
        return f"ConfigManager(config_path={self.config_path})"

    def __repr__(self) -> str:
        """Detailed form listing the known model names."""
        models = list(self.get_available_models().keys())
        return f"ConfigManager(config_path={self.config_path}, models={models})"
|
backend/inference/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .inference_engine import InferenceEngine
|
| 2 |
+
|
| 3 |
+
__all__ = ['InferenceEngine']
|
backend/inference/inference_engine.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import threading
|
| 3 |
+
import time
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, List, Tuple, Union, Any, Optional, Callable
|
| 7 |
+
import gradio as gr
|
| 8 |
+
from ..models.model_manager import ModelManager
|
| 9 |
+
from ..utils.data_processing import extract_file_dict, validate_data, extract_binary_output
|
| 10 |
+
from ..config.config_manager import ConfigManager
|
| 11 |
+
from ..utils.metrics import create_accuracy_table
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
import boto3
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class InferenceEngine:
|
| 17 |
+
"""Engine for handling batch inference and processing control."""
|
| 18 |
+
|
| 19 |
+
    def __init__(self, model_manager: ModelManager, config_manager: ConfigManager):
        """
        Initialize the inference engine.

        Args:
            model_manager: Model manager instance
            config_manager: Configuration manager instance
        """
        self.model_manager = model_manager
        self.config_manager = config_manager
        # Lock guarding the cross-thread stop flag below; all reads/writes of
        # stop_processing go through it.
        self.processing_lock = threading.Lock()
        # Cooperative cancellation flag, checked between images by the
        # batch-processing loops.
        self.stop_processing = False
        self.full_df = None  # Store full dataframe with image paths
|
| 32 |
+
|
| 33 |
+
    def set_stop_flag(self) -> str:
        """Set the global stop flag to interrupt processing.

        Returns:
            A status string suitable for display in the UI.
        """
        # Flip the shared flag under the lock so concurrent readers in the
        # processing loop never see a torn update.
        with self.processing_lock:
            self.stop_processing = True
            print("π Stop signal received. Processing will halt after current image...")
            return "π Stopping process... Please wait for current image to complete."
|
| 39 |
+
|
| 40 |
+
    def reset_stop_flag(self) -> None:
        """Reset the global stop flag before starting new processing."""
        with self.processing_lock:
            self.stop_processing = False
|
| 44 |
+
|
| 45 |
+
    def check_stop_flag(self) -> bool:
        """Check if processing should be stopped.

        Returns:
            True when a stop has been requested via set_stop_flag().
        """
        with self.processing_lock:
            return self.stop_processing
|
| 49 |
+
|
| 50 |
+
def _should_load_model(self, model_selection: str, quantization_type: str) -> bool:
|
| 51 |
+
"""
|
| 52 |
+
Check if we need to load the model.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
model_selection: Selected model name
|
| 56 |
+
quantization_type: Selected quantization type
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
True if model needs to be loaded, False otherwise
|
| 60 |
+
"""
|
| 61 |
+
# If no model is loaded, we need to load
|
| 62 |
+
if not self.model_manager.current_model or not self.model_manager.current_model.is_model_loaded():
|
| 63 |
+
return True
|
| 64 |
+
|
| 65 |
+
# If different model is selected, we need to load
|
| 66 |
+
if self.model_manager.current_model_name != model_selection:
|
| 67 |
+
return True
|
| 68 |
+
|
| 69 |
+
# If same model but different quantization, we need to reload
|
| 70 |
+
if self.model_manager.current_model.current_quantization != quantization_type:
|
| 71 |
+
return True
|
| 72 |
+
|
| 73 |
+
return False
|
| 74 |
+
|
| 75 |
+
def _ensure_correct_model_loaded(self, model_selection: str, quantization_type: str, progress: gr.Progress()) -> None:
|
| 76 |
+
"""
|
| 77 |
+
Ensure the correct model with correct quantization is loaded.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
model_selection: Selected model name
|
| 81 |
+
quantization_type: Selected quantization type
|
| 82 |
+
progress: Gradio progress object
|
| 83 |
+
"""
|
| 84 |
+
if self._should_load_model(model_selection, quantization_type):
|
| 85 |
+
progress(0, desc=f"π Loading {model_selection} ({quantization_type})...")
|
| 86 |
+
print(f"π Loading {model_selection} with {quantization_type}...")
|
| 87 |
+
success = self.model_manager.load_model(model_selection, quantization_type)
|
| 88 |
+
if not success:
|
| 89 |
+
raise Exception(f"Failed to load model {model_selection} with {quantization_type}")
|
| 90 |
+
else:
|
| 91 |
+
print(f"β
Correct model already loaded: {model_selection} with {quantization_type}")
|
| 92 |
+
|
| 93 |
+
def process_folder_input(
|
| 94 |
+
self,
|
| 95 |
+
folder_path: List[Path],
|
| 96 |
+
prompt: str,
|
| 97 |
+
quantization_type: str,
|
| 98 |
+
model_selection: str,
|
| 99 |
+
progress: gr.Progress()
|
| 100 |
+
) -> Tuple[Any, ...]:
|
| 101 |
+
"""
|
| 102 |
+
Process input folder with images and optional CSV.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
folder_path: List of Path objects from Gradio
|
| 106 |
+
prompt: Text prompt for inference
|
| 107 |
+
quantization_type: Model quantization type
|
| 108 |
+
model_selection: Selected model name
|
| 109 |
+
progress: Gradio progress object
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
Tuple of UI update states and results
|
| 113 |
+
"""
|
| 114 |
+
# Reset stop flag at the beginning of processing
|
| 115 |
+
self.reset_stop_flag()
|
| 116 |
+
|
| 117 |
+
# Extract file dictionary
|
| 118 |
+
file_dict = extract_file_dict(folder_path)
|
| 119 |
+
|
| 120 |
+
# Print all file names for debug
|
| 121 |
+
for fname in file_dict:
|
| 122 |
+
print(fname)
|
| 123 |
+
|
| 124 |
+
validation_result, message = validate_data(file_dict)
|
| 125 |
+
|
| 126 |
+
# Handle different validation results
|
| 127 |
+
if validation_result == False:
|
| 128 |
+
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), message, gr.update(visible=False), ""
|
| 129 |
+
elif validation_result in ["no_csv", "multiple_csv"]:
|
| 130 |
+
return self._process_without_csv(file_dict, prompt, quantization_type, model_selection, progress)
|
| 131 |
+
else:
|
| 132 |
+
return self._process_with_csv(file_dict, prompt, quantization_type, model_selection, progress)
|
| 133 |
+
|
| 134 |
+
    def _process_without_csv(
        self,
        file_dict: Dict[str, Path],
        prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress
    ) -> Tuple[Any, ...]:
        """Process images without CSV file.

        Runs the prompt against every image in ``file_dict`` and builds one
        result row per image. There is no CSV, so 'Ground Truth' is left
        blank for the user to fill in via the editable results table.

        NOTE(review): several emoji glyphs in the progress/log strings below
        appear mojibake'd in the scraped source ('π', 'β') — confirm the
        intended characters against the original file.

        Args:
            file_dict: Mapping of uploaded file name -> Path.
            prompt: Text prompt for inference.
            quantization_type: Model quantization type.
            model_selection: Selected model name.
            progress: Gradio progress object.

        Returns:
            Tuple of Gradio UI updates, the display dataframe (or an error
            string), and a final status message.
        """
        # Keep only image files; any CSV/other files are ignored here.
        image_exts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']
        image_file_dict = {fname: file_dict[fname] for fname in file_dict
                           if any(fname.lower().endswith(ext) for ext in image_exts)}

        filtered_rows = []
        total_images = len(image_file_dict)

        if total_images == 0:
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), "No image files found.", gr.update(visible=False), ""

        # Ensure correct model is loaded (no-op if already active).
        self._ensure_correct_model_loaded(model_selection, quantization_type, progress)

        # Initialize progress
        progress(0, desc=f"π Starting to process {total_images} images...")
        print(f"Starting to process {total_images} images with {model_selection}...")

        for idx, (img_name, img_path) in enumerate(image_file_dict.items()):
            # Check stop flag before processing each image (cooperative cancel).
            if self.check_stop_flag():
                print(f"π Processing stopped by user at image {idx + 1}/{total_images}")
                # Add remaining images as "Not processed" entries so the table
                # still lists every file.
                for remaining_idx, (remaining_name, remaining_path) in enumerate(list(image_file_dict.items())[idx:]):
                    filtered_rows.append({
                        'S.No': idx + remaining_idx + 1,
                        'Image Name': remaining_name,
                        'Ground Truth': '',
                        'Binary Output': 'Not processed (stopped)',
                        'Model Output': 'Processing stopped by user',
                        'Image Path': str(remaining_path)
                    })

                # Display copy omits 'Image Path'; the full frame (with paths)
                # is kept on self for the image-preview feature.
                display_df = pd.DataFrame(filtered_rows)[['S.No', 'Image Name', 'Ground Truth', 'Binary Output', 'Model Output']]
                self.full_df = pd.DataFrame(filtered_rows)
                final_message = f"π Processing stopped by user. Completed {idx}/{total_images} images."
                print(final_message)
                return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), display_df, gr.update(visible=False), final_message

            try:
                # Update progress with current image info (truncate long names).
                current_progress = idx / total_images
                progress_msg = f"π Processing image {idx + 1}/{total_images}: {img_name[:30]}..." if len(img_name) > 30 else f"π Processing image {idx + 1}/{total_images}: {img_name}"
                progress(current_progress, desc=progress_msg)
                print(progress_msg)

                # Use model inference; skip the model entirely for empty prompts.
                model_output = self.model_manager.inference(str(img_path), prompt) if prompt else "No prompt provided"

                # Extract binary output (no ground truth available for
                # file-based processing, hence the empty string / empty list).
                binary_output = extract_binary_output(model_output, "", [])

                filtered_rows.append({
                    'S.No': idx + 1,
                    'Image Name': img_name,
                    'Ground Truth': '',  # Empty for manual input
                    'Binary Output': binary_output,
                    'Model Output': model_output,
                    'Image Path': str(img_path)
                })

                # Update progress after successful processing
                current_progress = (idx + 1) / total_images
                progress_msg = f"β Completed {idx + 1}/{total_images} images"
                progress(current_progress, desc=progress_msg)
                print(f"Successfully processed image {idx + 1} of {total_images}")

            except Exception as e:
                # Per-image failures are recorded as a row rather than
                # aborting the whole batch.
                print(f"Error processing image {idx + 1} of {total_images}: {str(e)}")
                filtered_rows.append({
                    'S.No': idx + 1,
                    'Image Name': img_name,
                    'Ground Truth': '',
                    'Binary Output': 'Enter the output manually',  # Default for errors
                    'Model Output': f"Error: {str(e)}",
                    'Image Path': str(img_path)
                })

                # Update progress even for errors
                current_progress = (idx + 1) / total_images
                progress_msg = f"β οΈ Processed {idx + 1}/{total_images} images (with errors)"
                progress(current_progress, desc=progress_msg)

        # Check if processing was completed or stopped
        # NOTE(review): on the stopped path here, len(filtered_rows) counts
        # every emitted row (processed + error rows) — verify this is the
        # intended "completed" count.
        if self.check_stop_flag():
            final_message = f"π Processing stopped by user. Completed {len(filtered_rows)}/{total_images} images."
        else:
            final_message = f"π Successfully completed processing all {total_images} images!"

        display_df = pd.DataFrame(filtered_rows)[['S.No', 'Image Name', 'Ground Truth', 'Binary Output', 'Model Output']]
        # Save the full dataframe (with Image Path) for preview
        self.full_df = pd.DataFrame(filtered_rows)
        self.save_results_to_s3(display_df)

        print(final_message)

        # Make the table editable for ground truth input
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), display_df, gr.update(visible=False), final_message
|
| 240 |
+
|
| 241 |
+
    def _process_with_csv(
        self,
        file_dict: Dict[str, Path],
        prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress
    ) -> Tuple[Any, ...]:
        """Process images with CSV file.

        Reads the (single) CSV from ``file_dict`` — expected columns
        'Image Name' and 'Ground Truth' — and runs the prompt against each
        image listed in the CSV that was actually uploaded.

        NOTE(review): emoji glyphs in the strings below appear mojibake'd in
        the scraped source ('π', 'β') — confirm against the original file.

        Args:
            file_dict: Mapping of uploaded file name -> Path (CSV + images).
            prompt: Text prompt for inference.
            quantization_type: Model quantization type.
            model_selection: Selected model name.
            progress: Gradio progress object.

        Returns:
            Tuple of Gradio UI updates, the display dataframe (or an error
            string), and a final status message.
        """
        # The caller's validation guarantees exactly one CSV at this point.
        csv_files = [fname for fname in file_dict if fname.lower().endswith('.csv')]
        csv_file = file_dict[csv_files[0]]
        df = pd.read_csv(csv_file)

        # Collect all ground truth values for unique keyword extraction
        all_ground_truths = [str(row['Ground Truth']) for idx, row in df.iterrows()
                             if pd.notna(row['Ground Truth']) and str(row['Ground Truth']).strip()]

        # Find image files
        image_exts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']
        image_file_dict = {fname: file_dict[fname] for fname in file_dict
                           if any(fname.lower().endswith(ext) for ext in image_exts)}

        # Only keep rows where image file exists
        filtered_rows = []
        matching_images = [row for idx, row in df.iterrows() if row['Image Name'] in image_file_dict]
        total_images = len(matching_images)

        if total_images == 0:
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), "No matching images found for entries in CSV.", gr.update(visible=False), ""

        # Ensure correct model is loaded (no-op if already active).
        self._ensure_correct_model_loaded(model_selection, quantization_type, progress)

        # Initialize progress
        progress(0, desc=f"π Starting to process {total_images} images...")
        print(f"Starting to process {total_images} images with {model_selection}...")
        processed_count = 0

        for idx, row in df.iterrows():
            img_name = row['Image Name']
            if img_name in image_file_dict:
                # Check stop flag before processing each image
                if self.check_stop_flag():
                    print(f"π Processing stopped by user at image {processed_count + 1}/{total_images}")
                    # Add remaining unprocessed images.
                    # NOTE(review): `idx` is the iterrows LABEL but iloc is
                    # POSITIONAL — this only lines up if the CSV has a default
                    # RangeIndex. Verify, or the remaining-rows slice is wrong.
                    for remaining_idx, remaining_row in df.iloc[idx:].iterrows():
                        if remaining_row['Image Name'] in image_file_dict:
                            filtered_rows.append({
                                'S.No': len(filtered_rows) + 1,
                                'Image Name': remaining_row['Image Name'],
                                'Ground Truth': remaining_row['Ground Truth'],
                                'Binary Output': 'Not processed (stopped)',
                                'Model Output': 'Processing stopped by user',
                                'Image Path': str(image_file_dict[remaining_row['Image Name']])
                            })

                    # Display copy omits 'Image Path'; the full frame is kept
                    # on self for the image-preview feature.
                    display_df = pd.DataFrame(filtered_rows)[['S.No', 'Image Name', 'Ground Truth', 'Binary Output', 'Model Output']]
                    self.full_df = pd.DataFrame(filtered_rows)
                    final_message = f"π Processing stopped by user. Completed {processed_count}/{total_images} images."
                    print(final_message)
                    return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), display_df, gr.update(visible=False), final_message

                try:
                    processed_count += 1
                    # Update progress with current image info (truncate long names).
                    current_progress = (processed_count - 1) / total_images
                    progress_msg = f"π Processing image {processed_count}/{total_images}: {img_name[:30]}..." if len(img_name) > 30 else f"π Processing image {processed_count}/{total_images}: {img_name}"
                    progress(current_progress, desc=progress_msg)
                    print(progress_msg)

                    # Use model inference
                    model_output = self.model_manager.inference(str(image_file_dict[img_name]), prompt)

                    # Extract binary output using ground truth and all ground truths for keyword extraction
                    ground_truth = str(row['Ground Truth']) if pd.notna(row['Ground Truth']) else ""
                    binary_output = extract_binary_output(model_output, ground_truth, all_ground_truths)

                    filtered_rows.append({
                        'S.No': len(filtered_rows) + 1,
                        'Image Name': img_name,
                        'Ground Truth': row['Ground Truth'],
                        'Binary Output': binary_output,
                        'Model Output': model_output,
                        'Image Path': str(image_file_dict[img_name])
                    })

                    # Update progress after successful processing
                    current_progress = processed_count / total_images
                    progress_msg = f"β Completed {processed_count}/{total_images} images"
                    progress(current_progress, desc=progress_msg)
                    print(f"Successfully processed image {processed_count} of {total_images}")

                except Exception as e:
                    # Per-image failures become rows instead of aborting the batch.
                    print(f"Error processing image {processed_count} of {total_images}: {str(e)}")
                    filtered_rows.append({
                        'S.No': len(filtered_rows) + 1,
                        'Image Name': img_name,
                        'Ground Truth': row['Ground Truth'],
                        'Binary Output': 'Enter the output manually',  # Default for errors
                        'Model Output': f"Error: {str(e)}",
                        'Image Path': str(image_file_dict[img_name])
                    })

                    # Update progress even for errors
                    current_progress = processed_count / total_images
                    progress_msg = f"β οΈ Processed {processed_count}/{total_images} images (with errors)"
                    progress(current_progress, desc=progress_msg)

        # Check if processing was completed or stopped.
        # NOTE(review): the stopped-count below filters on the substring
        # 'stopped' in Model Output, which also excludes any real model output
        # that happens to contain 'stopped' — verify intent.
        if self.check_stop_flag():
            final_message = f"π Processing stopped by user. Completed {len([r for r in filtered_rows if 'stopped' not in r['Model Output']])}/{total_images} images."
        else:
            final_message = f"π Successfully completed processing all {total_images} images!"

        display_df = pd.DataFrame(filtered_rows)[['S.No', 'Image Name', 'Ground Truth', 'Binary Output', 'Model Output']]
        # Save the full dataframe (with Image Path) for preview
        self.full_df = pd.DataFrame(filtered_rows)

        self.save_results_to_s3(display_df)

        print(final_message)

        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), display_df, gr.update(visible=False), final_message
|
| 364 |
+
|
| 365 |
+
def rerun_with_new_prompt(
    self,
    df: pd.DataFrame,
    new_prompt: str,
    quantization_type: str,
    model_selection: str,
    progress=gr.Progress()
) -> Tuple[Any, ...]:
    """Rerun inference over every image with a new prompt and clear stale accuracy data.

    Args:
        df: Display dataframe shown in the UI (without the 'Image Path' column).
        new_prompt: Replacement prompt to run against every image.
        quantization_type: Quantization mode the model should be loaded with.
        model_selection: Name of the model to use for inference.
        progress: Gradio progress tracker. NOTE: this must be a keyword
            *default* (``progress=gr.Progress()``), not a bare annotation,
            otherwise Gradio does not inject the tracker and the event
            handler fails with a missing-argument error.

    Returns:
        Tuple of Gradio updates: (updated dataframe, three cleared accuracy
        outputs, two visibility updates, status message string).
    """
    if df is None or not new_prompt.strip():
        return df, None, None, None, gr.update(visible=False), gr.update(visible=False), "⚠️ Please provide a valid prompt"

    # Reset stop flag at the beginning of reprocessing so a previous
    # "stop" click does not abort this run immediately.
    self.reset_stop_flag()

    updated_df = df.copy()
    total_images = len(updated_df)

    # Collect all non-empty ground truth values; extract_binary_output uses
    # them to derive the unique keyword set for classification.
    all_ground_truths = [str(row['Ground Truth']) for idx, row in updated_df.iterrows()
                         if pd.notna(row['Ground Truth']) and str(row['Ground Truth']).strip()]

    # self.full_df carries the 'Image Path' column the display df lacks.
    if self.full_df is None:
        return df, None, None, None, gr.update(visible=False), gr.update(visible=False), "⚠️ No image data available"

    # Work on a copy so a user stop mid-run leaves a consistent snapshot.
    updated_full_df = self.full_df.copy()

    # Ensure the requested model/quantization combination is loaded.
    self._ensure_correct_model_loaded(model_selection, quantization_type, progress)

    # Initialize progress display.
    progress(0, desc=f"🔄 Starting to reprocess {total_images} images with new prompt...")
    print(f"🔄 Starting to reprocess {total_images} images with new prompt...")

    for i in range(len(updated_df)):
        # Honor the stop flag before each image.
        if self.check_stop_flag():
            print(f"🛑 Reprocessing stopped by user at image {i + 1}/{total_images}")
            # Mark remaining images as not reprocessed in both dataframes.
            for j in range(i, len(updated_df)):
                updated_df.iloc[j, updated_df.columns.get_loc("Model Output")] = "Reprocessing stopped by user"
                updated_df.iloc[j, updated_df.columns.get_loc("Binary Output")] = "Not reprocessed (stopped)"
                # Also update the full dataframe (guard: lengths can differ).
                if j < len(updated_full_df):
                    updated_full_df.iloc[j, updated_full_df.columns.get_loc("Model Output")] = "Reprocessing stopped by user"
                    updated_full_df.iloc[j, updated_full_df.columns.get_loc("Binary Output")] = "Not reprocessed (stopped)"

            # Publish the partially-updated full dataframe.
            self.full_df = updated_full_df

            final_message = f"🛑 Reprocessing stopped by user. Completed {i}/{total_images} images."
            print(final_message)
            return updated_df, None, None, None, gr.update(visible=False), gr.update(visible=False), final_message

        try:
            # Image path lives only in the full dataframe.
            image_path = self.full_df.iloc[i]['Image Path']
            image_name = updated_df.iloc[i]['Image Name']
            ground_truth = str(updated_df.iloc[i]['Ground Truth']) if pd.notna(updated_df.iloc[i]['Ground Truth']) else ""

            # Update progress with current image info (truncate long names).
            current_progress = i / total_images
            progress_msg = f"🔄 Reprocessing image {i + 1}/{total_images}: {image_name[:30]}..." if len(image_name) > 30 else f"🔄 Reprocessing image {i + 1}/{total_images}: {image_name}"
            progress(current_progress, desc=progress_msg)
            print(progress_msg)

            # Run model inference with the new prompt.
            model_output = self.model_manager.inference(image_path, new_prompt)

            # Keep display and full dataframes in sync.
            updated_df.iloc[i, updated_df.columns.get_loc("Model Output")] = model_output
            updated_full_df.iloc[i, updated_full_df.columns.get_loc("Model Output")] = model_output

            # Derive the binary classification from the free-text output.
            binary_output = extract_binary_output(model_output, ground_truth, all_ground_truths)
            updated_df.iloc[i, updated_df.columns.get_loc("Binary Output")] = binary_output
            updated_full_df.iloc[i, updated_full_df.columns.get_loc("Binary Output")] = binary_output

            # Update progress after successful processing.
            current_progress = (i + 1) / total_images
            progress_msg = f"✅ Completed {i + 1}/{total_images} images"
            progress(current_progress, desc=progress_msg)
            print(f"✅ Successfully reprocessed image {i + 1}/{total_images}")

        except Exception as e:
            # Per-image failures are recorded but never abort the batch.
            print(f"❌ Error reprocessing image {i + 1}/{total_images}: {str(e)}")
            error_message = f"Error: {str(e)}"

            # Record the error in both dataframes; binary output is left for
            # the user to fill in manually.
            updated_df.iloc[i, updated_df.columns.get_loc("Model Output")] = error_message
            updated_df.iloc[i, updated_df.columns.get_loc("Binary Output")] = "Enter the output manually"
            updated_full_df.iloc[i, updated_full_df.columns.get_loc("Model Output")] = error_message
            updated_full_df.iloc[i, updated_full_df.columns.get_loc("Binary Output")] = "Enter the output manually"

            # Advance the progress bar even for failed images.
            current_progress = (i + 1) / total_images
            progress_msg = f"⚠️ Processed {i + 1}/{total_images} images (with errors)"
            progress(current_progress, desc=progress_msg)

    # Publish the updated full dataframe.
    self.full_df = updated_full_df

    # Report whether the run completed or was stopped after the last image.
    if self.check_stop_flag():
        final_message = f"🛑 Reprocessing stopped by user. Completed reprocessing for some images."
    else:
        final_message = f"🎉 Successfully completed reprocessing all {total_images} images with new prompt! Click 'Generate Metrics' to see accuracy data."
    self.save_results_to_s3(updated_full_df)

    print(final_message)

    # Return updated dataframe and clear accuracy data (hide section 3).
    return updated_df, None, None, None, gr.update(visible=False), gr.update(visible=False), final_message
|
| 480 |
+
|
| 481 |
+
def save_results_to_s3(self, df):
    """Best-effort upload of results (CSV + metrics report) to S3.

    Writes *df* to a timestamped CSV and, when metrics can be computed by
    ``create_accuracy_table`` (it raises e.g. "No valid data" / "Need at
    least 2 different" labels otherwise), an evaluation-metrics text file.
    Both are uploaded under ``{AWS_PREFIX}/{date}/`` in ``AWS_BUCKET`` and
    the local temporary files are removed afterwards — including on
    failure, so no temp files leak. Errors are logged, never raised, so a
    broken upload cannot abort the UI flow.

    Args:
        df: Results dataframe to persist.
    """
    # Single time snapshot so the date folder and the filename timestamp
    # cannot disagree when called around midnight.
    now = datetime.now()
    s3_bucket = os.getenv('AWS_BUCKET')
    prefix = os.getenv('AWS_PREFIX')
    s3_path = f"{prefix}/{now.date()}"
    date_time = now.strftime("%Y-%m-%d_%H-%M-%S")
    csv_file_name = f'{date_time}_model_output.csv'
    text_file_name = f'{date_time}_evaluation_metrics.txt'

    try:
        try:
            # Metrics can legitimately fail (not enough labels / no valid
            # data); in that case we still upload the raw CSV below.
            metrics_df, _, cm_values = create_accuracy_table(df)
            with open(text_file_name, 'w') as f:
                f.write(metrics_df.to_string() + '\n\n')
                f.write(cm_values.to_string())
            status = self.upload_file(text_file_name, s3_bucket, f"{s3_path}/{text_file_name}")
            print(f"Status of uploading {text_file_name} to {s3_bucket}/{s3_path}/{text_file_name}: {status}")
        except Exception as e:
            print(f"Error saving metrics to s3: {e}")

        # Always persist the raw results CSV, even when metrics failed.
        df.to_csv(csv_file_name, index=False)
        status = self.upload_file(csv_file_name, s3_bucket, f"{s3_path}/{csv_file_name}")
        print(f"Status of uploading {csv_file_name} to {s3_bucket}/{s3_path}/{csv_file_name}: {status}")
    except Exception as e:
        print(f"Error saving results to s3: {e}")
    finally:
        # Clean up local temp files regardless of how far we got.
        for file_name in (text_file_name, csv_file_name):
            if os.path.exists(file_name):
                os.remove(file_name)
                print(f"Deleted {file_name}")
|
| 521 |
+
|
| 522 |
+
def upload_file(self, file_name, bucket, object_name=None):
    """Upload a file to an S3 bucket.

    :param file_name: File to upload
    :param bucket: Bucket to upload to
    :param object_name: S3 object name. If not specified then file_name is used
    :return: True if file was uploaded, else False
    """
    # Credentials are read from the environment on every call so rotated
    # keys take effect without a restart.
    access_key = os.getenv('AWS_ACCESS_KEY_ID')
    secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')

    # If S3 object_name was not specified, use the file's base name.
    if object_name is None:
        object_name = os.path.basename(file_name)

    # Upload the file. boto3's upload_file returns None and raises on
    # failure, so there is no response object to capture.
    s3_client = boto3.client('s3', aws_access_key_id=access_key, aws_secret_access_key=secret_key)
    try:
        s3_client.upload_file(file_name, bucket, object_name)
    except Exception as e:
        print(f"Error uploading {file_name} to s3: {e}")
        return False
    return True
|
backend/models/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_model import BaseModel
|
| 2 |
+
from .model_manager import ModelManager
|
| 3 |
+
from .internvl import InternVLModel
|
| 4 |
+
from .qwen import QwenModel
|
| 5 |
+
|
| 6 |
+
__all__ = ['BaseModel', 'ModelManager', 'InternVLModel', 'QwenModel']
|
backend/models/base_model.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Dict, Any, Optional, List
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import AutoModel, AutoTokenizer
|
| 5 |
+
|
| 6 |
+
class BaseModel(ABC):
|
| 7 |
+
"""Abstract base class for all vision-language models."""
|
| 8 |
+
|
| 9 |
+
def __init__(self, model_name: str, model_config: Dict[str, Any]):
|
| 10 |
+
"""
|
| 11 |
+
Initialize the base model.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
model_name: Name of the model
|
| 15 |
+
model_config: Configuration dictionary for the model
|
| 16 |
+
"""
|
| 17 |
+
self.model_name = model_name
|
| 18 |
+
self.model_config = model_config
|
| 19 |
+
self.model_id = model_config['model_id']
|
| 20 |
+
self.model = None
|
| 21 |
+
self.tokenizer = None
|
| 22 |
+
self.current_quantization = None
|
| 23 |
+
self.is_loaded = False
|
| 24 |
+
|
| 25 |
+
@abstractmethod
|
| 26 |
+
def load_model(self, quantization_type: str, **kwargs) -> bool:
|
| 27 |
+
"""
|
| 28 |
+
Load the model with specified quantization.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
quantization_type: Type of quantization to use
|
| 32 |
+
**kwargs: Additional arguments for model loading
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
True if successful, False otherwise
|
| 36 |
+
"""
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
@abstractmethod
|
| 40 |
+
def unload_model(self) -> None:
|
| 41 |
+
"""Unload the model from memory."""
|
| 42 |
+
pass
|
| 43 |
+
|
| 44 |
+
@abstractmethod
|
| 45 |
+
def inference(self, image_path: str, prompt: str, **kwargs) -> str:
|
| 46 |
+
"""
|
| 47 |
+
Perform inference on an image with a text prompt.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
image_path: Path to the image file
|
| 51 |
+
prompt: Text prompt for the model
|
| 52 |
+
**kwargs: Additional inference parameters
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
Model's text response
|
| 56 |
+
"""
|
| 57 |
+
pass
|
| 58 |
+
|
| 59 |
+
def is_model_loaded(self) -> bool:
|
| 60 |
+
"""Check if model is currently loaded."""
|
| 61 |
+
return self.is_loaded
|
| 62 |
+
|
| 63 |
+
def get_model_info(self) -> Dict[str, Any]:
|
| 64 |
+
"""Get information about the model."""
|
| 65 |
+
return {
|
| 66 |
+
'name': self.model_name,
|
| 67 |
+
'model_id': self.model_id,
|
| 68 |
+
'description': self.model_config.get('description', ''),
|
| 69 |
+
'min_gpu_memory_gb': self.model_config.get('min_gpu_memory_gb', 0),
|
| 70 |
+
'recommended_gpu_memory_gb': self.model_config.get('recommended_gpu_memory_gb', 0),
|
| 71 |
+
'supported_quantizations': self.model_config.get('supported_quantizations', []),
|
| 72 |
+
'default_quantization': self.model_config.get('default_quantization', ''),
|
| 73 |
+
'is_loaded': self.is_loaded,
|
| 74 |
+
'current_quantization': self.current_quantization
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
def get_supported_quantizations(self) -> List[str]:
|
| 78 |
+
"""Get list of supported quantization methods."""
|
| 79 |
+
return self.model_config.get('supported_quantizations', [])
|
| 80 |
+
|
| 81 |
+
def get_memory_requirements(self) -> Dict[str, int]:
|
| 82 |
+
"""Get memory requirements for the model."""
|
| 83 |
+
return {
|
| 84 |
+
'min_gpu_memory_gb': self.model_config.get('min_gpu_memory_gb', 0),
|
| 85 |
+
'recommended_gpu_memory_gb': self.model_config.get('recommended_gpu_memory_gb', 0)
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
def validate_quantization(self, quantization_type: str) -> bool:
|
| 89 |
+
"""
|
| 90 |
+
Validate if the quantization type is supported.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
quantization_type: Quantization type to validate
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
True if supported, False otherwise
|
| 97 |
+
"""
|
| 98 |
+
supported = self.get_supported_quantizations()
|
| 99 |
+
return quantization_type in supported
|
| 100 |
+
|
| 101 |
+
def __str__(self) -> str:
|
| 102 |
+
"""String representation of the model."""
|
| 103 |
+
status = "loaded" if self.is_loaded else "not loaded"
|
| 104 |
+
quant = f" ({self.current_quantization})" if self.current_quantization else ""
|
| 105 |
+
return f"{self.model_name}{quant} - {status}"
|
| 106 |
+
|
| 107 |
+
def __repr__(self) -> str:
|
| 108 |
+
"""Detailed string representation."""
|
| 109 |
+
return f"BaseModel(name={self.model_name}, loaded={self.is_loaded}, quantization={self.current_quantization})"
|
backend/models/internvl/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .internvl_model import InternVLModel
|
| 2 |
+
|
| 3 |
+
__all__ = ['InternVLModel']
|
backend/models/internvl/internvl_model.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import gc
|
| 3 |
+
import os
|
| 4 |
+
from typing import Dict, Any, Optional, Callable
|
| 5 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
| 6 |
+
from ..base_model import BaseModel
|
| 7 |
+
from ...utils.image_processing import load_image
|
| 8 |
+
from ...config.config_manager import ConfigManager
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class InternVLModel(BaseModel):
    """InternVL3 model implementation: download, multi-GPU placement, inference."""

    def __init__(self, model_name: str, model_config: Dict[str, Any], config_manager: ConfigManager):
        """
        Initialize the InternVL model.

        Args:
            model_name: Name of the model
            model_config: Configuration dictionary for the model
            config_manager: Configuration manager instance
        """
        super().__init__(model_name, model_config)
        self.config_manager = config_manager

        # Expandable segments reduce CUDA memory fragmentation across
        # repeated load/unload cycles.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    def check_model_exists_locally(self) -> bool:
        """Check if the model exists in the local Hugging Face cache."""
        try:
            from transformers.utils import cached_file
            cached_file(self.model_id, "config.json", local_files_only=True)
            return True
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed; any lookup failure means "not cached".
            return False

    def download_model_with_progress(self, progress_callback: Optional[Callable] = None) -> bool:
        """
        Download the model with progress tracking.

        Args:
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise
        """
        try:
            if progress_callback:
                progress_callback("📥 Downloading tokenizer...")

            # Download tokenizer first (smaller). The return value is
            # intentionally discarded: this call only warms the HF cache.
            AutoTokenizer.from_pretrained(
                self.model_id,
                trust_remote_code=True,
                use_fast=False
            )

            if progress_callback:
                progress_callback("📥 Downloading model weights... This may take several minutes...")

            # Download the model config; weights follow on first load.
            AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)

            if progress_callback:
                progress_callback("✅ Model downloaded successfully!")

            return True
        except Exception as e:
            if progress_callback:
                progress_callback(f"❌ Download failed: {str(e)}")
            return False

    def split_model(self) -> "Dict[str, int] | str":
        """
        Distribute LLM layers across GPUs, keeping the vision encoder on GPU 0.

        Returns:
            A device-map dictionary, or the string ``"auto"`` when fewer than
            two GPUs are available (the original annotation claimed
            ``Dict[str, int]`` only, which missed this path).
        """
        device_map = {}
        world_size = torch.cuda.device_count()

        if world_size < 2:
            return "auto"  # let transformers decide

        cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
        num_layers = cfg.llm_config.num_hidden_layers  # type: ignore[attr-defined]

        # Treat GPU 0 as having ~0.3 GPU of capacity for LLM layers, since
        # the vision encoder also lives there.
        effective_gpus = world_size - 0.7
        layers_per_gpu = num_layers / effective_gpus

        # Calculate per-GPU layer counts.
        gpu_layers = []
        for i in range(world_size):
            if i == 0:
                # GPU 0 gets fewer layers due to the vision model.
                gpu_layers.append(max(1, int(layers_per_gpu * 0.3)))
            else:
                gpu_layers.append(int(layers_per_gpu))

        # Adjust if the rounded totals don't match num_layers.
        total_assigned = sum(gpu_layers)
        diff = num_layers - total_assigned
        if diff > 0:
            # Add remaining layers to non-zero GPUs.
            for i in range(1, min(world_size, diff + 1)):
                gpu_layers[i] += 1
        elif diff < 0:
            # Remove excess layers from GPU 0.
            gpu_layers[0] = max(1, gpu_layers[0] + diff)

        # Assign transformer layers to devices in order.
        layer_cnt = 0
        for gpu_id, num_layers_on_gpu in enumerate(gpu_layers):
            for _ in range(num_layers_on_gpu):
                if layer_cnt < num_layers:
                    device_map[f'language_model.model.layers.{layer_cnt}'] = gpu_id
                    layer_cnt += 1

        # Distribute the remaining components across GPUs.
        last_gpu = world_size - 1

        # Vision model must stay on GPU 0 (its projector mlp1 too).
        device_map['vision_model'] = 0
        device_map['mlp1'] = 0

        # Embeddings on GPU 0; norm/output/lm_head on the last GPU.
        device_map['language_model.model.tok_embeddings'] = 0
        device_map['language_model.model.embed_tokens'] = 0
        device_map['language_model.model.norm'] = last_gpu
        device_map['language_model.model.rotary_emb'] = 1 if world_size > 1 else 0
        device_map['language_model.output'] = last_gpu
        device_map['language_model.lm_head'] = last_gpu

        # Keep the last layer on the same GPU as the output layers for
        # compatibility (avoids a cross-device hop before lm_head).
        device_map[f'language_model.model.layers.{num_layers - 1}'] = last_gpu

        print(f"Layer distribution: {gpu_layers}")
        print(f"Total layers: {num_layers}, Assigned: {sum(gpu_layers)}")

        return device_map

    def load_model(self, quantization_type: str, progress_callback: Optional[Callable] = None) -> bool:
        """
        Load the model with the specified quantization.

        Args:
            quantization_type: Type of quantization to use
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise

        Raises:
            ValueError: If the quantization type is unsupported.
            Exception: If download or loading fails (model is unloaded first).
        """
        if not self.validate_quantization(quantization_type):
            raise ValueError(f"Quantization type '{quantization_type}' not supported for {self.model_name}")

        # No-op when the same model/quantization pair is already resident.
        if (self.model is not None and self.tokenizer is not None and
                self.current_quantization == quantization_type):
            if progress_callback:
                progress_callback(f"✅ {self.model_name} already loaded!")
            return True

        print(f"Loading {self.model_name} with {quantization_type} quantization...")
        if progress_callback:
            progress_callback(f"🔄 Loading {self.model_name} with {quantization_type} quantization...")

        try:
            # Download the model first if it's not in the local cache.
            model_exists = self.check_model_exists_locally()
            if not model_exists:
                if progress_callback:
                    progress_callback(f"📥 {self.model_name} not found locally. Starting download...")
                print(f"Model {self.model_name} not found locally. Starting download...")
                success = self.download_model_with_progress(progress_callback)
                if not success:
                    raise Exception(f"Failed to download {self.model_name}")
            else:
                if progress_callback:
                    progress_callback(f"✅ {self.model_name} found locally.")

            # Clear any previously loaded weights to free GPU memory.
            if self.model is not None:
                self.unload_model()

            # Print memory before loading (debug aid).
            self._print_gpu_memory("before loading")

            if progress_callback:
                progress_callback(f"🔄 Loading {self.model_name} tokenizer...")

            # Load tokenizer.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                trust_remote_code=True,
                use_fast=False
            )

            # Load model weights according to the requested quantization.
            if "non-quantized" in quantization_type:
                if progress_callback:
                    progress_callback(f"🔄 Loading {self.model_name} model in 16-bit precision...")

                device_map = self.split_model()
                print(f"Device map for multi-GPU: {device_map}")

                # Try loading with the custom device_map, fall back to "auto"
                # if it fails. Some InternVL models (e.g. InternVL3_5) don't
                # support custom device_map due to a missing
                # 'all_tied_weights_keys' attribute.
                try:
                    self.model = AutoModel.from_pretrained(
                        self.model_id,
                        torch_dtype=torch.bfloat16,
                        low_cpu_mem_usage=True,
                        use_flash_attn=True,
                        trust_remote_code=True,
                        device_map=device_map,
                    ).eval()
                except (AttributeError, TypeError, RuntimeError, ValueError) as e:
                    error_str = str(e).lower()
                    # Only fall back for device_map-related errors, most
                    # notably the all_tied_weights_keys issue above.
                    if ("all_tied_weights_keys" in error_str or
                            "tied_weights" in error_str or
                            ("device_map" in error_str and "attribute" in error_str)):
                        print(f"⚠️ Custom device_map failed ({str(e)}), falling back to 'auto' device_map...")
                        if progress_callback:
                            progress_callback(f"⚠️ Using automatic device mapping...")
                        self.model = AutoModel.from_pretrained(
                            self.model_id,
                            torch_dtype=torch.bfloat16,
                            low_cpu_mem_usage=True,
                            use_flash_attn=True,
                            trust_remote_code=True,
                            device_map="auto",
                        ).eval()
                    else:
                        # Re-raise if it's a different error.
                        raise
            else:  # quantized (8bit)
                if progress_callback:
                    progress_callback(f"🔄 Loading {self.model_name} model with 8-bit quantization...")

                print("Loading with 8-bit quantization to reduce memory usage...")
                self.model = AutoModel.from_pretrained(
                    self.model_id,
                    torch_dtype=torch.bfloat16,
                    load_in_8bit=True,
                    low_cpu_mem_usage=True,
                    use_flash_attn=True,
                    trust_remote_code=True,
                    device_map="auto"  # let transformers place the quantized model
                ).eval()

            # Verify model and tokenizer are properly loaded.
            if self.model is None:
                raise Exception(f"Model failed to load for {self.model_name}")
            if self.tokenizer is None:
                raise Exception(f"Tokenizer failed to load for {self.model_name}")

            self.current_quantization = quantization_type
            self.is_loaded = True

            success_msg = f"✅ {self.model_name} loaded successfully with {quantization_type} quantization!"
            print(success_msg)
            if progress_callback:
                progress_callback(success_msg)

            # Print GPU memory usage after loading (debug aid).
            self._print_gpu_memory("after loading")

            return True

        except Exception as e:
            error_msg = f"Failed to load model {self.model_name}: {str(e)}"
            print(error_msg)
            if progress_callback:
                progress_callback(f"❌ {error_msg}")

            # Reset to a clean unloaded state on failure, then propagate.
            self.unload_model()
            raise Exception(error_msg)

    def unload_model(self) -> None:
        """Unload the model from memory and release GPU caches."""
        if self.model is not None:
            print("🧹 Clearing model from memory...")
            del self.model
            self.model = None

        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None

        self.current_quantization = None
        self.is_loaded = False

        # Clear GPU cache before and after garbage collection, since gc may
        # release tensors whose memory is still held by the allocator.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear again after gc

        print("✅ Model unloaded successfully")

    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference on an image with a text prompt.

        Args:
            image_path: Path to the image file
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters (currently unused)

        Returns:
            Model's text response, or an error string on failure.

        Raises:
            RuntimeError: If called before load_model().
        """
        if not self.is_loaded:
            raise RuntimeError(f"Model {self.model_name} is not loaded. Call load_model() first.")

        try:
            # Load and preprocess the image (448px tiles, up to 12 tiles).
            pixel_values = load_image(image_path, input_size=448, max_num=12).to(torch.bfloat16)

            # Move pixel_values onto the same device as the model weights.
            if torch.cuda.is_available():
                model_device = next(self.model.parameters()).device
                pixel_values = pixel_values.to(model_device)
            else:
                # Fallback to CPU if no CUDA is available.
                pixel_values = pixel_values.cpu()

            # InternVL expects the <image> placeholder ahead of the prompt.
            formatted_prompt = f"<image>\n{prompt}" if prompt else "<image>\n"

            # Generation configuration. NOTE: do_sample=True means outputs
            # are not deterministic across calls.
            gen_cfg = dict(max_new_tokens=1024, do_sample=True)

            # Perform inference via the model's chat interface.
            response = self.model.chat(self.tokenizer, pixel_values, formatted_prompt, gen_cfg)
            return response

        except Exception as e:
            # Errors are returned as text so batch processing can continue.
            error_msg = f"Error processing image: {str(e)}"
            print(error_msg)
            return error_msg

    def _print_gpu_memory(self, stage: str) -> None:
        """Print per-GPU allocated/reserved memory for debugging."""
        if torch.cuda.is_available():
            print(f"Memory {stage}:")
            for i in range(torch.cuda.device_count()):
                allocated = torch.cuda.memory_allocated(i) / 1024**3
                reserved = torch.cuda.memory_reserved(i) / 1024**3
                print(f"GPU {i}: Allocated {allocated:.2f} GB, Reserved {reserved:.2f} GB")
|
backend/models/model_manager.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
from typing import Dict, Any, Optional, Callable
|
| 3 |
+
from .base_model import BaseModel
|
| 4 |
+
from .internvl import InternVLModel
|
| 5 |
+
from .qwen import QwenModel
|
| 6 |
+
from ..config.config_manager import ConfigManager
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ModelManager:
    """Manager class for handling multiple vision-language models.

    Owns one model instance per configured entry, tracks which one is
    currently loaded, and serializes load/unload operations behind a lock.
    """

    def __init__(self, config_manager: ConfigManager):
        """
        Initialize the model manager.

        Args:
            config_manager: Configuration manager instance
        """
        self.config_manager = config_manager
        # All known models keyed by display name (instantiated, not yet loaded).
        self.models: Dict[str, BaseModel] = {}
        self.current_model: Optional[BaseModel] = None
        self.current_model_name: Optional[str] = None
        # Guards load/unload so concurrent UI callbacks cannot interleave.
        self.loading_lock = threading.Lock()

        # Apply environment settings
        self.config_manager.apply_environment_settings()

        # Initialize models but don't load them yet
        self._initialize_models()

    def _get_model_class(self, model_config: Dict[str, Any]) -> type:
        """
        Determine the appropriate model class based on model configuration.

        Args:
            model_config: Model configuration dictionary

        Returns:
            Model class to instantiate
        """
        model_type = model_config.get('model_type', 'internvl').lower()
        model_id = model_config.get('model_id', '').lower()

        # Determine model type based on model_id or explicit model_type
        if 'qwen' in model_id or model_type == 'qwen':
            return QwenModel
        elif 'internvl' in model_id or model_type == 'internvl':
            return InternVLModel
        else:
            # Default to InternVL for backward compatibility
            print(f"⚠️ Unknown model type for {model_config.get('name', 'unknown')}, defaulting to InternVL")
            return InternVLModel

    def _initialize_models(self) -> None:
        """Initialize model instances without loading them (weights stay on disk)."""
        available_models = self.config_manager.get_available_models()

        for model_name, model_id in available_models.items():
            model_config = self.config_manager.get_model_config(model_name)

            # Determine the appropriate model class
            model_class = self._get_model_class(model_config)

            # Create model instance
            self.models[model_name] = model_class(
                model_name=model_name,
                model_config=model_config,
                config_manager=self.config_manager
            )

            print(f"✅ Initialized {model_class.__name__}: {model_name}")

    def get_available_models(self) -> list[str]:
        """Get list of available model names."""
        return list(self.models.keys())

    def get_model_info(self, model_name: str) -> Dict[str, Any]:
        """
        Get information about a specific model.

        Args:
            model_name: Name of the model

        Returns:
            Model information dictionary

        Raises:
            KeyError: If the model name is not configured.
        """
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")

        return self.models[model_name].get_model_info()

    def get_all_models_info(self) -> Dict[str, Dict[str, Any]]:
        """Get information about all available models."""
        return {name: model.get_model_info() for name, model in self.models.items()}

    def load_model(
        self,
        model_name: str,
        quantization_type: str,
        progress_callback: Optional[Callable] = None
    ) -> bool:
        """
        Load a specific model with given quantization.

        Only one model is kept in memory at a time: loading a different model
        first unloads the current one.

        Args:
            model_name: Name of the model to load
            quantization_type: Type of quantization to use
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise

        Raises:
            KeyError: If the model name is not configured.
        """
        with self.loading_lock:
            if model_name not in self.models:
                raise KeyError(f"Model '{model_name}' not available")

            model = self.models[model_name]

            # Check if this model is already loaded with the same quantization
            if (self.current_model == model and
                model.is_model_loaded() and
                model.current_quantization == quantization_type):
                if progress_callback:
                    progress_callback(f"✅ {model_name} already loaded with {quantization_type}!")
                return True

            # Unload current model if different
            if (self.current_model and
                self.current_model != model and
                self.current_model.is_model_loaded()):
                if progress_callback:
                    progress_callback(f"🔄 Unloading {self.current_model_name}...")
                self.current_model.unload_model()

            # Load the requested model
            try:
                success = model.load_model(quantization_type, progress_callback)
                if success:
                    self.current_model = model
                    self.current_model_name = model_name
                    print(f"✅ Successfully loaded {model_name} with {quantization_type}")
                    return True
                else:
                    if progress_callback:
                        progress_callback(f"❌ Failed to load {model_name}")
                    return False
            except Exception as e:
                error_msg = f"Error loading {model_name}: {str(e)}"
                print(error_msg)
                if progress_callback:
                    progress_callback(f"❌ {error_msg}")
                return False

    def unload_current_model(self) -> None:
        """Unload the currently loaded model (no-op when nothing is loaded)."""
        with self.loading_lock:
            if self.current_model and self.current_model.is_model_loaded():
                print(f"🔄 Unloading {self.current_model_name}...")
                self.current_model.unload_model()
                self.current_model = None
                self.current_model_name = None
                print("✅ Model unloaded successfully")
            else:
                print("ℹ️ No model currently loaded")

    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference using the currently loaded model.

        Args:
            image_path: Path to the image file
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters

        Returns:
            Model's text response

        Raises:
            RuntimeError: If no model is currently loaded.
        """
        if not self.current_model or not self.current_model.is_model_loaded():
            raise RuntimeError("No model is currently loaded. Load a model first.")

        return self.current_model.inference(image_path, prompt, **kwargs)

    def get_current_model_status(self) -> str:
        """Get status string for the currently loaded model."""
        if not self.current_model or not self.current_model.is_model_loaded():
            return "❌ No model loaded"

        quantization = self.current_model.current_quantization or "Unknown"
        model_class = self.current_model.__class__.__name__
        return f"✅ {self.current_model_name} ({model_class}) loaded with {quantization}"

    def get_supported_quantizations(self, model_name: str) -> list[str]:
        """Get supported quantization methods for a model."""
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")

        return self.models[model_name].get_supported_quantizations()

    def validate_model_and_quantization(self, model_name: str, quantization_type: str) -> bool:
        """
        Validate if a model and quantization combination is valid.

        Args:
            model_name: Name of the model
            quantization_type: Type of quantization

        Returns:
            True if valid, False otherwise
        """
        if model_name not in self.models:
            return False

        return self.models[model_name].validate_quantization(quantization_type)

    def get_model_memory_requirements(self, model_name: str) -> Dict[str, int]:
        """Get memory requirements for a specific model."""
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")

        return self.models[model_name].get_memory_requirements()

    def preload_default_model(self) -> bool:
        """
        Preload the default model specified in configuration.

        Returns:
            True if successful, False otherwise
        """
        default_model = self.config_manager.get_default_model()
        default_quantization = self.config_manager.get_default_quantization(default_model)

        print(f"🚀 Preloading default model: {default_model} with {default_quantization}")

        try:
            return self.load_model(default_model, default_quantization)
        except Exception as e:
            print(f"⚠️ Failed to preload default model: {str(e)}")
            return False

    def __str__(self) -> str:
        """String representation of the model manager."""
        loaded_info = f"Current: {self.current_model_name}" if self.current_model_name else "None loaded"
        return f"ModelManager({len(self.models)} models available, {loaded_info})"

    def __repr__(self) -> str:
        """Detailed string representation."""
        models_list = list(self.models.keys())
        return f"ModelManager(models={models_list}, current={self.current_model_name})"
|
backend/models/qwen/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .qwen_model import QwenModel
|
| 2 |
+
|
| 3 |
+
__all__ = ['QwenModel']
|
backend/models/qwen/qwen_model.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import gc
|
| 3 |
+
import os
|
| 4 |
+
from typing import Dict, Any, Optional, Callable
|
| 5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 6 |
+
from ..base_model import BaseModel
|
| 7 |
+
from ...config.config_manager import ConfigManager
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class QwenModel(BaseModel):
    """Qwen2.5 model implementation.

    Text-only chat model: ``inference`` ignores the image argument so the
    class can share the BaseModel interface with the vision models.
    """

    def __init__(self, model_name: str, model_config: Dict[str, Any], config_manager: ConfigManager):
        """
        Initialize the Qwen model.

        Args:
            model_name: Name of the model
            model_config: Configuration dictionary for the model
            config_manager: Configuration manager instance
        """
        super().__init__(model_name, model_config)
        self.config_manager = config_manager

        # Set environment variable for CUDA memory allocation
        # (reduces fragmentation when repeatedly loading/unloading large models)
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    def check_model_exists_locally(self) -> bool:
        """Check if the model's config.json is present in the local HF cache."""
        try:
            from transformers.utils import cached_file
            cached_file(self.model_id, "config.json", local_files_only=True)
            return True
        except Exception:
            # Fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Any lookup failure here simply
            # means the model is not cached locally.
            return False

    def download_model_with_progress(self, progress_callback: Optional[Callable] = None) -> bool:
        """
        Download model with progress tracking.

        Populates the Hugging Face cache by loading tokenizer and weights to
        CPU once, then discarding them.

        Args:
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise
        """
        try:
            if progress_callback:
                progress_callback("📥 Downloading tokenizer...")

            # Download tokenizer first (smaller)
            tokenizer = AutoTokenizer.from_pretrained(self.model_id)

            if progress_callback:
                progress_callback("📥 Downloading model weights... This may take several minutes...")

            # Download model config and weights by trying to load the model
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                torch_dtype="auto",
                device_map="cpu",  # Just download, don't load to GPU yet
                low_cpu_mem_usage=True
            )

            # Clean up the test loading (weights remain in the on-disk cache)
            del model

            if progress_callback:
                progress_callback("✅ Model downloaded successfully!")

            return True
        except Exception as e:
            if progress_callback:
                progress_callback(f"❌ Download failed: {str(e)}")
            return False

    def load_model(self, quantization_type: str, progress_callback: Optional[Callable] = None) -> bool:
        """
        Load the model with specified quantization.

        Args:
            quantization_type: Type of quantization to use
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise

        Raises:
            ValueError: If the quantization type is not supported.
            Exception: If downloading or loading fails (after cleanup).
        """
        if not self.validate_quantization(quantization_type):
            raise ValueError(f"Quantization type '{quantization_type}' not supported for {self.model_name}")

        # If model is already loaded with the same quantization, return early
        if (self.model is not None and self.tokenizer is not None and
            self.current_quantization == quantization_type):
            if progress_callback:
                progress_callback(f"✅ {self.model_name} already loaded!")
            return True

        print(f"Loading {self.model_name} with {quantization_type} quantization...")
        if progress_callback:
            progress_callback(f"🔄 Loading {self.model_name} with {quantization_type} quantization...")

        try:
            # Download the model first if it is not in the local cache
            model_exists = self.check_model_exists_locally()
            if not model_exists:
                if progress_callback:
                    progress_callback(f"📥 {self.model_name} not found locally. Starting download...")
                print(f"Model {self.model_name} not found locally. Starting download...")
                success = self.download_model_with_progress(progress_callback)
                if not success:
                    raise Exception(f"Failed to download {self.model_name}")
            else:
                if progress_callback:
                    progress_callback(f"✅ {self.model_name} found locally.")

            # Clear existing model if any (e.g. loaded with another quantization)
            if self.model is not None:
                self.unload_model()

            # Print memory before loading (debug aid)
            self._print_gpu_memory("before loading")

            if progress_callback:
                progress_callback(f"🔄 Loading {self.model_name} tokenizer...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

            # Load model based on quantization type
            if progress_callback:
                progress_callback(f"🔄 Loading {self.model_name} model...")

            if "non-quantized" in quantization_type:
                # Load with auto dtype and device mapping
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    torch_dtype="auto",
                    device_map="auto",
                    low_cpu_mem_usage=True
                )
            else:  # quantized (8bit)
                print("Loading with 8-bit quantization to reduce memory usage...")
                # NOTE(review): `load_in_8bit` is deprecated in recent
                # transformers releases in favor of BitsAndBytesConfig —
                # left unchanged here to preserve behavior; migrate when
                # bumping the transformers version.
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    torch_dtype="auto",
                    load_in_8bit=True,
                    device_map="auto",
                    low_cpu_mem_usage=True
                )

            # Verify model and tokenizer are properly loaded
            if self.model is None:
                raise Exception(f"Model failed to load for {self.model_name}")
            if self.tokenizer is None:
                raise Exception(f"Tokenizer failed to load for {self.model_name}")

            self.current_quantization = quantization_type
            self.is_loaded = True

            success_msg = f"✅ {self.model_name} loaded successfully with {quantization_type} quantization!"
            print(success_msg)
            if progress_callback:
                progress_callback(success_msg)

            # Print GPU memory usage after loading
            self._print_gpu_memory("after loading")

            return True

        except Exception as e:
            error_msg = f"Failed to load model {self.model_name}: {str(e)}"
            print(error_msg)
            if progress_callback:
                progress_callback(f"❌ {error_msg}")

            # Reset on failure so no half-loaded state lingers
            self.unload_model()
            raise Exception(error_msg)

    def unload_model(self) -> None:
        """Unload the model from memory and release GPU caches."""
        if self.model is not None:
            print("🧹 Clearing model from memory...")
            del self.model
            self.model = None

        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None

        self.current_quantization = None
        self.is_loaded = False

        # Clear GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Force garbage collection
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear again after gc

        print("✅ Model unloaded successfully")

    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference with a text prompt.
        Note: Qwen2.5 is a text-only model, so image_path is ignored.

        Args:
            image_path: Path to the image file (ignored for text-only models)
            prompt: Text prompt for the model
            **kwargs: Generation overrides: max_new_tokens (512), do_sample
                (True), temperature (0.7), top_p (0.9)

        Returns:
            Model's text response, or an error string on failure

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self.is_loaded:
            raise RuntimeError(f"Model {self.model_name} is not loaded. Call load_model() first.")

        if not prompt or not prompt.strip():
            return "Error: No prompt provided"

        try:
            # Prepare messages for chat format
            messages = [
                {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ]

            # Apply chat template
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            # Tokenize input and move it to the model's device
            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

            # Generate response
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=kwargs.get('max_new_tokens', 512),
                do_sample=kwargs.get('do_sample', True),
                temperature=kwargs.get('temperature', 0.7),
                top_p=kwargs.get('top_p', 0.9),
                pad_token_id=self.tokenizer.eos_token_id
            )

            # Extract only the generated part (remove input tokens)
            generated_ids = [
                output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]

            # Decode response
            response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return response

        except Exception as e:
            # Errors are returned as text so the UI can display them
            error_msg = f"Error processing prompt: {str(e)}"
            print(error_msg)
            return error_msg

    def _print_gpu_memory(self, stage: str) -> None:
        """Print GPU memory usage for debugging (no-op without CUDA)."""
        if torch.cuda.is_available():
            print(f"Memory {stage}:")
            for i in range(torch.cuda.device_count()):
                allocated = torch.cuda.memory_allocated(i) / 1024**3
                reserved = torch.cuda.memory_reserved(i) / 1024**3
                print(f"GPU {i}: Allocated {allocated:.2f} GB, Reserved {reserved:.2f} GB")
backend/utils/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .image_processing import (
|
| 2 |
+
build_transform,
|
| 3 |
+
find_closest_aspect_ratio,
|
| 4 |
+
dynamic_preprocess,
|
| 5 |
+
load_image
|
| 6 |
+
)
|
| 7 |
+
from .data_processing import (
|
| 8 |
+
extract_file_dict,
|
| 9 |
+
validate_data,
|
| 10 |
+
extract_binary_output
|
| 11 |
+
)
|
| 12 |
+
from .metrics import (
|
| 13 |
+
create_confusion_matrix_plot,
|
| 14 |
+
create_accuracy_table,
|
| 15 |
+
save_dataframe_to_csv
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
'build_transform',
|
| 20 |
+
'find_closest_aspect_ratio',
|
| 21 |
+
'dynamic_preprocess',
|
| 22 |
+
'load_image',
|
| 23 |
+
'extract_file_dict',
|
| 24 |
+
'validate_data',
|
| 25 |
+
'extract_binary_output',
|
| 26 |
+
'create_confusion_matrix_plot',
|
| 27 |
+
'create_accuracy_table',
|
| 28 |
+
'save_dataframe_to_csv'
|
| 29 |
+
]
|
backend/utils/data_processing.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Dict, List, Tuple, Union, Any
|
| 5 |
+
|
| 6 |
+
def extract_file_dict(folder_path: List[Path]) -> Dict[str, Path]:
    """
    Extract file dictionary from folder path.

    Args:
        folder_path: List of Path objects from Gradio file upload

    Returns:
        Dictionary mapping filename (final path component) to full path.
        If two uploads share a filename, the later one wins.
    """
    # Path.name is already the final component — the previous
    # `filepath.name.split("/")[-1]` was a no-op and has been removed.
    return {file.name: file for file in folder_path}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def validate_data(file_dict: Dict[str, Path]) -> Tuple[Union[bool, str], str]:
    """
    Validate the structure of an uploaded file set.

    Args:
        file_dict: Mapping of filename to path for every uploaded file.

    Returns:
        Tuple of (validation_result, message) where validation_result is:
            True            - valid data with a usable CSV
            False           - invalid data (no images / bad CSV)
            "no_csv"        - images present but no CSV file
            "multiple_csv"  - images present but more than one CSV file
    """
    image_suffixes = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')

    csv_names = [name for name in file_dict if name.lower().endswith('.csv')]
    has_images = any(name.lower().endswith(image_suffixes) for name in file_dict)

    # Guard: at least one image is mandatory.
    if not has_images:
        return False, "No image files found in the folder or subfolders"

    # Zero or multiple CSVs: fall back to path/name-based processing.
    if not csv_names:
        return "no_csv", "No CSV file found. Will extract data from file paths and names."
    if len(csv_names) > 1:
        return "multiple_csv", "Multiple CSV files found. Will extract data from file paths and names."

    # Exactly one CSV: it must be readable and carry the required columns.
    try:
        frame = pd.read_csv(file_dict[csv_names[0]])
    except Exception as e:
        return False, f"Error reading CSV file: {str(e)}"

    if 'Ground Truth' not in frame.columns:
        return False, "CSV file does not contain 'Ground Truth' column"
    if 'Image Name' not in frame.columns:
        return False, "CSV file does not contain 'Image Name' column"

    return True, "Data validation successful"
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def extract_binary_output(
|
| 69 |
+
model_output: str,
|
| 70 |
+
ground_truth: str = "",
|
| 71 |
+
all_ground_truths: List[str] = None
|
| 72 |
+
) -> str:
|
| 73 |
+
"""
|
| 74 |
+
Extract binary output from model response based on unique ground truth keywords.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
model_output: The model's text response
|
| 78 |
+
ground_truth: Current item's ground truth (for fallback)
|
| 79 |
+
all_ground_truths: List of all ground truth values to extract unique keywords
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
Extracted keyword that best matches the model output
|
| 83 |
+
"""
|
| 84 |
+
if all_ground_truths is None:
|
| 85 |
+
all_ground_truths = []
|
| 86 |
+
|
| 87 |
+
# Unique lowercase keywords
|
| 88 |
+
unique_keywords = sorted({str(gt).strip().lower() for gt in all_ground_truths if gt})
|
| 89 |
+
|
| 90 |
+
# Take only the first line of model output
|
| 91 |
+
first_line = model_output.split("\n", 1)[0].lower()
|
| 92 |
+
|
| 93 |
+
print(f"DEBUG: Unique keywords extracted: {first_line}")
|
| 94 |
+
print(f"DEBUG: Model output: {model_output[:100]}...") # First 100 chars
|
| 95 |
+
|
| 96 |
+
for keyword in unique_keywords:
|
| 97 |
+
if keyword in first_line:
|
| 98 |
+
return keyword
|
| 99 |
+
|
| 100 |
+
return "Enter the output manually"
|
backend/utils/image_processing.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torchvision.transforms as T
|
| 5 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 6 |
+
from typing import List, Tuple, Union
|
| 7 |
+
|
| 8 |
+
# Constants from InternVL preprocessing
|
| 9 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 10 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def build_transform(input_size: int = 448) -> T.Compose:
    """Build the torchvision preprocessing pipeline used by InternVL.

    Args:
        input_size: Side length the image is resized to (default: 448).

    Returns:
        A ``T.Compose`` that converts to RGB, resizes with bicubic
        interpolation, converts to a tensor and normalizes with the
        ImageNet mean/std used during InternVL pre-training.
    """
    to_rgb = T.Lambda(lambda img: img if img.mode == "RGB" else img.convert("RGB"))
    resize = T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC)
    normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    return T.Compose([to_rgb, resize, T.ToTensor(), normalize])
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: List[Tuple[int, int]],
    width: int,
    height: int,
    image_size: int
) -> Tuple[int, int]:
    """Pick the target tiling ratio closest to the image's aspect ratio.

    Args:
        aspect_ratio: Current image aspect ratio (width / height).
        target_ratios: Candidate (columns, rows) tile layouts.
        width: Original image width in pixels.
        height: Original image height in pixels.
        image_size: Side length of a single tile.

    Returns:
        The best matching (columns, rows) tuple; defaults to (1, 1).
    """
    best = (1, 1)
    smallest_diff = float("inf")
    area = width * height

    for cols, rows in target_ratios:
        diff = abs(aspect_ratio - cols / rows)
        # On an exact tie, prefer the layout with more tiles only when the
        # source image is large enough to fill at least half of them.
        tie_break = (diff == smallest_diff and
                     area > 0.5 * image_size * image_size * cols * rows)
        if diff < smallest_diff or tie_break:
            smallest_diff = diff
            best = (cols, rows)

    return best
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def dynamic_preprocess(
    image: Image.Image,
    min_num: int = 1,
    max_num: int = 12,
    image_size: int = 448,
    use_thumbnail: bool = False
) -> List[Image.Image]:
    """
    Split an arbitrarily-sized image into <= max_num tiles of image_size x
    image_size (InternVL spec), choosing the tile grid whose aspect ratio
    best matches the source image.

    Args:
        image: Input PIL Image
        min_num: Minimum number of tiles
        max_num: Maximum number of tiles
        image_size: Size of each square tile (default: 448)
        use_thumbnail: Whether to append a whole-image thumbnail tile

    Returns:
        List of processed image tiles (plus an optional thumbnail)
    """
    ow, oh = image.size
    aspect_ratio = ow / oh

    # Generate every (cols, rows) grid whose tile count falls in
    # [min_num, max_num], deduplicated via a set and sorted by tile count.
    target_ratios = sorted(
        {(i, j) for n in range(min_num, max_num + 1)
         for i in range(1, n + 1)
         for j in range(1, n + 1)
         if min_num <= i * j <= max_num},
        key=lambda x: x[0] * x[1],
    )

    # ratio is (columns, rows); the image is resized to exactly fill that grid.
    ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, ow, oh, image_size)
    tw, th = image_size * ratio[0], image_size * ratio[1]
    blocks = ratio[0] * ratio[1]

    resized = image.resize((tw, th))

    # Crop tiles in row-major order: idx % cols is the column,
    # idx // cols is the row (tw // image_size == number of columns).
    tiles = []
    for idx in range(blocks):
        tile = resized.crop((
            (idx % (tw // image_size)) * image_size,
            (idx // (tw // image_size)) * image_size,
            ((idx % (tw // image_size)) + 1) * image_size,
            ((idx // (tw // image_size)) + 1) * image_size,
        ))
        tiles.append(tile)

    # Add thumbnail only when the image was actually split; a single tile
    # already is a whole-image view.
    if use_thumbnail and blocks != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def load_image(
    path: str,
    input_size: int = 448,
    max_num: int = 12
) -> torch.Tensor:
    """Load an image file and preprocess it into InternVL input tiles.

    Args:
        path: Path to the image file.
        input_size: Side length of each tile (default: 448).
        max_num: Maximum number of tiles (default: 12).

    Returns:
        Tensor of shape (N, 3, H, W) ready for InternVL, where N is the
        number of tiles (including the thumbnail when the image is split).
    """
    source = Image.open(path).convert("RGB")
    tiles = dynamic_preprocess(
        source,
        image_size=input_size,
        use_thumbnail=True,
        max_num=max_num,
    )
    preprocess = build_transform(input_size)
    tensors = [preprocess(tile) for tile in tiles]
    return torch.stack(tensors)
|
backend/utils/metrics.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import seaborn as sns
|
| 5 |
+
import tempfile
|
| 6 |
+
from typing import Tuple, Optional
|
| 7 |
+
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_confusion_matrix_plot(
    cm: np.ndarray,
    accuracy: float,
    labels: list = ['No', 'Yes']
) -> str:
    """
    Create a confusion matrix heatmap and save it to a temporary PNG file.

    Args:
        cm: Confusion matrix array (rows = ground truth, cols = prediction)
        labels: Axis tick labels for the two classes
        accuracy: Accuracy score shown in the plot title

    Returns:
        Path to the saved plot file (caller is responsible for cleanup)
    """
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.title(f'Confusion Matrix (Accuracy: {accuracy:.1%})')
    plt.ylabel('Ground Truth')
    plt.xlabel('Model Prediction')

    # tempfile.mktemp is deprecated and race-prone (the name can be claimed
    # by another process between mktemp() and savefig). NamedTemporaryFile
    # creates the file atomically; delete=False keeps it for the caller.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        plot_path = tmp.name
    try:
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure so repeated calls don't leak memory.
        plt.close()

    return plot_path
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def create_accuracy_table(df: pd.DataFrame) -> Tuple[pd.DataFrame, str, pd.DataFrame]:
    """
    Create accuracy metrics table and confusion matrix from results dataframe.

    Labels are lower-cased and collected from both columns; the distinct
    values are sorted alphabetically, the first maps to 0 and every other
    value maps to 1. All metrics are computed on that binary encoding.

    Args:
        df: DataFrame with 'Ground Truth' and 'Binary Output' columns

    Returns:
        Tuple of (metrics_df, confusion_matrix_plot_path, confusion_matrix_values_df)

    Raises:
        ValueError: If insufficient data for binary classification
    """
    # Work on a copy so the caller's dataframe is never mutated.
    df_copy = df.copy()

    # Get unique values from both Ground Truth and Binary Output
    # Convert to string first, then apply .str operations
    ground_truth_values = df_copy['Ground Truth'].dropna().astype(str).str.lower().unique()
    binary_output_values = df_copy['Binary Output'].dropna().astype(str).str.lower().unique()

    # Combine and get all unique values
    all_values = set(list(ground_truth_values) + list(binary_output_values))
    all_values = [v for v in all_values if v.strip()]  # Remove empty strings

    if len(all_values) < 2:
        raise ValueError("Need at least 2 different values for binary classification")

    # Sort values to ensure consistent mapping (alphabetical order)
    sorted_values = sorted(all_values)

    # Create mapping: first value (alphabetically) = 0, second = 1
    # This ensures consistent mapping regardless of order in data
    value_mapping = {sorted_values[0]: 0}
    if len(sorted_values) >= 2:
        value_mapping[sorted_values[1]] = 1

    # If there are more than 2 values, map the rest to 1 (positive class)
    for i in range(2, len(sorted_values)):
        value_mapping[sorted_values[i]] = 1

    print(f"Detected binary mapping: {value_mapping}")

    # Apply mapping - convert to string first, then apply .str operations
    df_copy['Ground Truth Binary'] = df_copy['Ground Truth'].astype(str).str.lower().map(value_mapping)
    df_copy['Binary Output Binary'] = df_copy['Binary Output'].astype(str).str.lower().map(value_mapping)

    # Remove rows where either ground truth or binary output is NaN
    # (i.e. values that did not appear in value_mapping after lowering).
    df_copy = df_copy.dropna(subset=['Ground Truth Binary', 'Binary Output Binary'])

    if len(df_copy) == 0:
        raise ValueError("No valid data for accuracy calculation after mapping. Check that Ground Truth and Binary Output contain valid binary values.")

    # Calculate metrics on the 0/1 encoding; zero_division=0 keeps
    # precision/recall/F1 defined when a class is never predicted.
    cm = confusion_matrix(df_copy['Ground Truth Binary'], df_copy['Binary Output Binary'])
    accuracy = accuracy_score(df_copy['Ground Truth Binary'], df_copy['Binary Output Binary'])
    precision = precision_score(df_copy['Ground Truth Binary'], df_copy['Binary Output Binary'], zero_division=0)
    recall = recall_score(df_copy['Ground Truth Binary'], df_copy['Binary Output Binary'], zero_division=0)
    f1 = f1_score(df_copy['Ground Truth Binary'], df_copy['Binary Output Binary'], zero_division=0)

    # Create metrics dataframe (values pre-formatted as display strings)
    metrics_data = [
        ["Accuracy", f"{accuracy:.3f}"],
        ["Precision", f"{precision:.3f}"],
        ["Recall", f"{recall:.3f}"],
        ["F1 Score", f"{f1:.3f}"],
        ["Total Samples", f"{len(df_copy)}"]
    ]
    metrics_df = pd.DataFrame(metrics_data, columns=["Metric", "Value"])

    # Create labels for confusion matrix based on detected values
    # Find the original case versions of the labels
    # NOTE(review): with more than two raw values every extra value maps to 1,
    # so this loop can yield more than two labels; only labels[0]/[1] are used
    # below — confirm multi-value inputs are intended.
    original_labels = []
    for mapped_val in sorted([k for k, v in value_mapping.items() if v in [0, 1]]):
        # Find original case version from the data
        original_case = None
        for val in df_copy['Ground Truth'].dropna():
            if str(val).lower() == mapped_val:
                original_case = str(val)
                break
        if original_case is None:
            for val in df_copy['Binary Output'].dropna():
                if str(val).lower() == mapped_val:
                    original_case = str(val)
                    break
        original_labels.append(original_case if original_case else mapped_val.title())

    # Ensure we have exactly 2 labels
    if len(original_labels) < 2:
        original_labels = ['Class 0', 'Class 1']

    cm_plot_path = create_confusion_matrix_plot(cm, accuracy, original_labels)

    # Confusion matrix values table (only meaningful when exactly 2 classes
    # survive; otherwise fall back to the raw matrix without labels).
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        cm_values = pd.DataFrame(
            [[tn, fp], [fn, tp]],
            columns=[f"Predicted {original_labels[0]}", f"Predicted {original_labels[1]}"],
            index=[f"Actual {original_labels[0]}", f"Actual {original_labels[1]}"]
        )
    else:
        cm_values = pd.DataFrame(cm)

    return metrics_df, cm_plot_path, cm_values
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def save_dataframe_to_csv(df: pd.DataFrame) -> Optional[str]:
    """
    Save a dataframe to a temporary CSV file.

    Args:
        df: DataFrame to save

    Returns:
        Path to saved CSV file, or None when df is None or empty
        (caller is responsible for deleting the file)
    """
    if df is None or df.empty:
        return None

    # tempfile.mktemp is deprecated and race-prone (another process can claim
    # the returned name before we write). NamedTemporaryFile creates the file
    # atomically; delete=False keeps it alive for the caller / download link.
    # newline='' prevents blank rows in the CSV on Windows.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, newline='') as tmp:
        temp_path = tmp.name
        df.to_csv(tmp, index=False)
    return temp_path
|
config/models.yaml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Configuration for Vision Language Models and Language Models
|
| 2 |
+
# This file contains model configurations for easy integration
|
| 3 |
+
|
| 4 |
+
models:
|
| 5 |
+
# InternVL Vision-Language Models
|
| 6 |
+
InternVL3-8B:
|
| 7 |
+
name: "InternVL3-8B"
|
| 8 |
+
model_id: "OpenGVLab/InternVL3-8B"
|
| 9 |
+
model_type: "internvl"
|
| 10 |
+
description: "Fastest model, good for quick processing"
|
| 11 |
+
supported_quantizations:
|
| 12 |
+
- "non-quantized(fp16)"
|
| 13 |
+
- "quantized(8bit)"
|
| 14 |
+
default_quantization: "non-quantized(fp16)"
|
| 15 |
+
|
| 16 |
+
InternVL3-14B:
|
| 17 |
+
name: "InternVL3-14B"
|
| 18 |
+
model_id: "OpenGVLab/InternVL3-14B"
|
| 19 |
+
model_type: "internvl"
|
| 20 |
+
description: "Balanced performance and quality"
|
| 21 |
+
supported_quantizations:
|
| 22 |
+
- "non-quantized(fp16)"
|
| 23 |
+
- "quantized(8bit)"
|
| 24 |
+
default_quantization: "quantized(8bit)"
|
| 25 |
+
|
| 26 |
+
InternVL3-38B:
|
| 27 |
+
name: "InternVL3-38B"
|
| 28 |
+
model_id: "OpenGVLab/InternVL3-38B"
|
| 29 |
+
model_type: "internvl"
|
| 30 |
+
description: "Highest quality, requires significant GPU memory"
|
| 31 |
+
supported_quantizations:
|
| 32 |
+
- "non-quantized(fp16)"
|
| 33 |
+
- "quantized(8bit)"
|
| 34 |
+
default_quantization: "quantized(8bit)"
|
| 35 |
+
|
| 36 |
+
InternVL3_5-8B:
|
| 37 |
+
name: "InternVL3_5-8B"
|
| 38 |
+
model_id: "OpenGVLab/InternVL3_5-8B"
|
| 39 |
+
model_type: "internvl"
|
| 40 |
+
description: "Fastest model, good for quick processing"
|
| 41 |
+
supported_quantizations:
|
| 42 |
+
- "non-quantized(fp16)"
|
| 43 |
+
- "quantized(8bit)"
|
| 44 |
+
default_quantization: "non-quantized(fp16)"
|
| 45 |
+
|
| 46 |
+
# Qwen Language Models (Text-only)
|
| 47 |
+
Qwen2.5-7B-Instruct:
|
| 48 |
+
name: "Qwen2.5-7B-Instruct"
|
| 49 |
+
model_id: "Qwen/Qwen2.5-7B-Instruct"
|
| 50 |
+
model_type: "qwen"
|
| 51 |
+
description: "Qwen2.5 7B instruction-tuned model for text generation"
|
| 52 |
+
supported_quantizations:
|
| 53 |
+
- "non-quantized(fp16)"
|
| 54 |
+
- "quantized(8bit)"
|
| 55 |
+
default_quantization: "quantized(8bit)"
|
| 56 |
+
|
| 57 |
+
Qwen2.5-14B-Instruct:
|
| 58 |
+
name: "Qwen2.5-14B-Instruct"
|
| 59 |
+
model_id: "Qwen/Qwen2.5-14B-Instruct"
|
| 60 |
+
model_type: "qwen"
|
| 61 |
+
description: "Qwen2.5 14B instruction-tuned model for better text generation"
|
| 62 |
+
supported_quantizations:
|
| 63 |
+
- "non-quantized(fp16)"
|
| 64 |
+
- "quantized(8bit)"
|
| 65 |
+
default_quantization: "quantized(8bit)"
|
| 66 |
+
|
| 67 |
+
# Default model selection
|
| 68 |
+
default_model: "InternVL3-8B"
|
debug_files.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
File Upload Diagnostic Script
|
| 4 |
+
This script helps debug why some images are not being processed.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List
|
| 10 |
+
|
| 11 |
+
def analyze_uploaded_files(folder_path: str) -> None:
    """
    Analyze uploaded files to understand why some images might not be processed.

    Walks the folder recursively, classifies files into images vs non-images,
    inspects the first CSV found for an 'Image Name' column, and reports which
    CSV-listed images actually exist on disk. Purely diagnostic: prints to
    stdout and returns nothing.

    Args:
        folder_path: Path to the uploaded folder
    """
    print("π File Upload Diagnostic Tool")
    print("=" * 50)

    if not os.path.exists(folder_path):
        print(f"β Folder not found: {folder_path}")
        return

    # Get all files in the folder (recursive walk; directories are skipped)
    all_files = []
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            full_path = os.path.join(root, file)
            all_files.append(Path(full_path))

    print(f"π Total files found: {len(all_files)}")
    print("\nπ All files:")
    for i, file_path in enumerate(all_files, 1):
        print(f"  {i}. {file_path.name} (ext: {file_path.suffix.lower()})")

    # Analyze image files (extension-based classification only; contents are
    # never inspected)
    image_exts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']
    print(f"\nπΌοΈ Looking for image extensions: {image_exts}")

    image_files = []
    non_image_files = []

    for file_path in all_files:
        # NOTE(review): suffix.endswith(ext) is effectively an equality test
        # here since both strings start with '.'; `suffix in image_exts`
        # would be clearer.
        if any(file_path.suffix.lower().endswith(ext) for ext in image_exts):
            image_files.append(file_path)
        else:
            non_image_files.append(file_path)

    print(f"\nβ Image files detected ({len(image_files)}):")
    for i, img in enumerate(image_files, 1):
        print(f"  {i}. {img.name}")

    print(f"\nπ Non-image files ({len(non_image_files)}):")
    for i, file in enumerate(non_image_files, 1):
        print(f"  {i}. {file.name} (ext: {file.suffix.lower()})")

    # Check for CSV files
    csv_files = [f for f in all_files if f.suffix.lower() == '.csv']
    print(f"\nπ CSV files found ({len(csv_files)}):")
    for i, csv in enumerate(csv_files, 1):
        print(f"  {i}. {csv.name}")

    # If CSV exists, check its content (only the first CSV is inspected)
    if csv_files:
        try:
            # Imported lazily so the diagnostic runs without pandas unless a
            # CSV is actually present.
            import pandas as pd
            df = pd.read_csv(csv_files[0])
            print(f"\nπ CSV Analysis for '{csv_files[0].name}':")
            print(f"   - Rows: {len(df)}")
            print(f"   - Columns: {list(df.columns)}")

            if 'Image Name' in df.columns:
                image_names_in_csv = df['Image Name'].tolist()
                print(f"   - Image names in CSV: {len(image_names_in_csv)}")

                # Check which images from CSV actually exist as files
                existing_images = []
                missing_images = []

                for img_name in image_names_in_csv:
                    if any(img.name == img_name for img in image_files):
                        existing_images.append(img_name)
                    else:
                        missing_images.append(img_name)

                print(f"\nπ CSV-to-File Matching:")
                print(f"   - Images in CSV that exist as files: {len(existing_images)}")
                print(f"   - Images in CSV that are missing: {len(missing_images)}")

                if existing_images:
                    print("   β Matching files:")
                    for img in existing_images:
                        print(f"     - {img}")

                if missing_images:
                    print("   β Missing files:")
                    for img in missing_images:
                        print(f"     - {img}")

        except Exception as e:
            print(f"   β Error reading CSV: {e}")

    # Summary
    print(f"\nπ SUMMARY:")
    print(f"   - Total files uploaded: {len(all_files)}")
    print(f"   - Image files detected: {len(image_files)}")
    print(f"   - CSV files: {len(csv_files)}")

    # NOTE(review): `'df' in locals()` couples the summary to the try-block
    # above having defined df; `existing_images` may likewise be unbound if
    # the CSV read failed mid-way — confirm that failure path.
    if csv_files and 'df' in locals():
        if 'Image Name' in df.columns:
            print(f"   - Images that will be processed: {len(existing_images)}")
        else:
            print(f"   - CSV exists but no 'Image Name' column - will process all {len(image_files)} images")
    else:
        print(f"   - No CSV - will process all {len(image_files)} images")
|
| 117 |
+
|
| 118 |
+
if __name__ == "__main__":
    # Interactive entry point: prompt for the folder and run the diagnostic.
    print("Please provide the path to your uploaded folder:")
    target_folder = input("Folder path: ").strip()
    analyze_uploaded_files(target_folder)
|
frontend/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .gradio_app import GradioApp
|
| 2 |
+
|
| 3 |
+
__all__ = ['GradioApp']
|
frontend/gradio_app.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import shutil
|
| 6 |
+
import tempfile
|
| 7 |
+
import uuid
|
| 8 |
+
import spaces
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
from backend import ConfigManager, ModelManager, InferenceEngine
|
| 12 |
+
from backend.utils.metrics import create_accuracy_table, save_dataframe_to_csv
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GradioApp:
|
| 16 |
+
"""Gradio application for InternVL3 prompt engineering."""
|
| 17 |
+
|
| 18 |
+
def __init__(self):
    """Initialize the Gradio application.

    Wires up the backend (ConfigManager -> ModelManager -> InferenceEngine)
    and opportunistically preloads the default model so the first request
    does not pay the model-load cost.
    """
    # Initialize backend components; ModelManager and InferenceEngine share
    # the same ConfigManager instance.
    self.config_manager = ConfigManager()
    self.model_manager = ModelManager(self.config_manager)
    self.inference_engine = InferenceEngine(self.model_manager, self.config_manager)

    # Try to preload default model. Failure is deliberately non-fatal:
    # the model is loaded lazily on first use instead.
    try:
        self.model_manager.preload_default_model()
        print("β Default model preloaded successfully!")
    except Exception as e:
        print(f"β οΈ Default model preloading failed: {str(e)}")
        print("The model will be loaded when first needed.")
|
| 32 |
+
|
| 33 |
+
def get_current_model_status(self) -> str:
    """Return a human-readable status string for the currently loaded model."""
    # Thin delegation: the ModelManager owns all model state.
    return self.model_manager.get_current_model_status()
|
| 36 |
+
|
| 37 |
+
def handle_stop_button(self):
    """Request cancellation of the running inference job.

    Returns:
        Tuple of (status message from the engine, gr.update making the
        stop-status component visible).
    """
    # Stopping is cooperative: the engine checks the flag between items.
    stop_message = self.inference_engine.set_stop_flag()
    return stop_message, gr.update(visible=True)
|
| 41 |
+
|
| 42 |
+
def on_model_change(self, model_selection: str, quantization_type: str) -> str:
    """React to model/quantization dropdown changes.

    Returns:
        A "will load ..." notice when the selection differs from the
        currently loaded model/quantization, otherwise the current status.
    """
    status = self.get_current_model_status()
    if not (model_selection and quantization_type):
        return status

    target_id = self.config_manager.get_available_models().get(model_selection)
    loaded = self.model_manager.current_model
    loaded_id = loaded.model_id if loaded else None

    quantization_differs = (loaded is not None and
                            loaded.current_quantization != quantization_type)
    if loaded_id != target_id or quantization_differs:
        return f"π Will load {model_selection} with {quantization_type} when processing starts"
    return status
|
| 57 |
+
|
| 58 |
+
def get_model_choices_with_info(self) -> list[str]:
    """Build dropdown labels of the form "<model name> (<TYPE>)"."""
    cfg = self.config_manager
    return [
        f"{name} ({cfg.get_model_config(name).get('model_type', 'unknown').upper()})"
        for name in cfg.get_available_models().keys()
    ]
|
| 66 |
+
|
| 67 |
+
def extract_model_name_from_choice(self, choice: str) -> str:
    """Strip the " (TYPE)" suffix added by get_model_choices_with_info.

    str.partition returns the whole string as the head when the separator
    is absent, so no explicit membership check is needed.
    """
    model_name, _, _ = choice.partition(' (')
    return model_name
|
| 70 |
+
|
| 71 |
+
def update_image_preview(self, evt: gr.SelectData, df, folder_path):
    """Update the image preview when a results-table row is selected.

    Args:
        evt: Gradio selection event; evt.index[0] is the selected row index.
        df: The dataframe currently displayed in the table.
        folder_path: Unused here; present to match the Gradio event wiring.

    Returns:
        Tuple of (path to a temp copy of the image or None, model output text).
    """
    if df is None or evt.index[0] >= len(df):
        return None, ""
    try:
        # Use the full dataframe with image paths — the displayed table does
        # not carry the "Image Path" column.
        full_df = getattr(self.inference_engine, 'full_df', None)
        if full_df is None or evt.index[0] >= len(full_df):
            return None, ""
        selected_row = full_df.iloc[evt.index[0]]
        image_path = selected_row["Image Path"]
        model_output = selected_row["Model Output"]
        if not os.path.exists(image_path):
            # Image missing on disk: still surface the model output text.
            return None, model_output
        # Copy into the system temp dir under a unique name so Gradio can
        # serve the file regardless of where the original lives.
        # NOTE(review): these preview copies are never deleted — confirm
        # whether accumulation matters for long-running sessions.
        file_extension = Path(image_path).suffix
        temp_filename = f"gradio_preview_{uuid.uuid4().hex}{file_extension}"
        temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
        shutil.copy2(image_path, temp_path)
        return temp_path, model_output
    except Exception as e:
        print(f"Error loading image preview: {e}")
        return None, ""
|
| 93 |
+
|
| 94 |
+
def download_results_csv(self, results_table_data):
    """Serialize the results table to a temporary CSV file for download.

    Args:
        results_table_data: Table contents as delivered by Gradio — may be
            a pandas DataFrame, a list of rows, or anything pd.DataFrame()
            can ingest.

    Returns:
        Path to the written CSV file, or None on empty input / failure.
    """
    try:
        print(f"Download function called with data type: {type(results_table_data)}")

        if results_table_data is None:
            print("No data to download")
            return None

        # Handle different data types from Gradio
        if hasattr(results_table_data, 'values'):
            # If it's a pandas DataFrame (duck-typed via the .values attribute)
            df = results_table_data
        elif isinstance(results_table_data, list):
            # If it's a list of lists or list of dicts; the column names must
            # match the table schema built by the inference engine.
            if len(results_table_data) == 0:
                print("Empty data")
                return None
            df = pd.DataFrame(results_table_data, columns=["S.No", "Image Name", "Ground Truth", "Binary Output", "Model Output"])
        else:
            # Try to convert to DataFrame
            df = pd.DataFrame(results_table_data)

        print(f"DataFrame shape: {df.shape}")
        print(f"DataFrame columns: {df.columns.tolist()}")

        # Create temporary file; delete=False so Gradio can serve the file
        # after the handle is closed.
        temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False)
        df.to_csv(temp_file.name, index=False)
        temp_file.close()

        print(f"CSV file created: {temp_file.name}")
        return temp_file.name

    except Exception as e:
        print(f"Error in download_results_csv: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
|
| 133 |
+
|
| 134 |
+
def submit_and_show_metrics(self, df):
    """Compute and display accuracy metrics for the results table.

    Args:
        df: Results dataframe with 'Ground Truth' and 'Binary Output' columns.

    Returns:
        8-tuple matching the Gradio outputs: (table df, state df, metrics df,
        confusion-matrix plot path, confusion-matrix values df,
        metrics-section visibility update, results-section visibility update,
        status message string).
    """
    if df is None:
        return df, df, None, None, None, gr.update(visible=False), gr.update(visible=False), ""

    # Only create metrics if all outputs are valid yes/no responses;
    # create_accuracy_table raises ValueError otherwise and we fall through
    # to the warning branch instead of crashing the UI.
    try:
        metrics_df, cm_plot_path, cm_values = create_accuracy_table(df)
        return df, df, metrics_df, cm_plot_path, cm_values, gr.update(visible=True), gr.update(visible=True), "π Metrics calculated successfully!"
    except Exception as e:
        print(f"Could not create metrics: {str(e)}")
        return df, df, None, None, None, gr.update(visible=False), gr.update(visible=True), f"β οΈ Could not calculate metrics: {str(e)}"
|
| 146 |
+
|
| 147 |
+
@spaces.GPU
|
| 148 |
+
def process_input_ui(self, folder_path, prompt, quantization_type, model_selection):
|
| 149 |
+
"""UI wrapper for processing input with progress updates."""
|
| 150 |
+
if not folder_path or not prompt.strip():
|
| 151 |
+
return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False),
|
| 152 |
+
"Please upload a folder and enter a prompt.", None, None, None,
|
| 153 |
+
gr.update(visible=False), gr.update(visible=False),
|
| 154 |
+
gr.update(value="β οΈ Please upload a folder and enter a prompt.", visible=True), "", gr.update(visible=False))
|
| 155 |
+
|
| 156 |
+
# Extract actual model name from the dropdown choice
|
| 157 |
+
actual_model_name = self.extract_model_name_from_choice(model_selection)
|
| 158 |
+
|
| 159 |
+
# Check if model needs to be downloaded and show progress
|
| 160 |
+
available_models = self.config_manager.get_available_models()
|
| 161 |
+
model_id = available_models[actual_model_name]
|
| 162 |
+
|
| 163 |
+
# Show processing message and hide stop status
|
| 164 |
+
yield (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
|
| 165 |
+
None, None, None, None,
|
| 166 |
+
gr.update(visible=False), gr.update(visible=False),
|
| 167 |
+
gr.update(value="π Initializing processing...", visible=True), prompt, gr.update(visible=False))
|
| 168 |
+
|
| 169 |
+
# Process the input
|
| 170 |
+
error, show_results, show_image, table, error_message, final_message = self.inference_engine.process_folder_input(
|
| 171 |
+
folder_path, prompt, quantization_type, actual_model_name, gr.Progress()
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# If error is visible, show results section but keep error visible
|
| 175 |
+
if error["visible"]:
|
| 176 |
+
yield (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True),
|
| 177 |
+
error, None, None, None,
|
| 178 |
+
gr.update(visible=False), gr.update(visible=False),
|
| 179 |
+
gr.update(value=final_message, visible=True), prompt, gr.update(visible=False))
|
| 180 |
+
else:
|
| 181 |
+
yield (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True),
|
| 182 |
+
None, show_results, show_image, table,
|
| 183 |
+
gr.update(visible=True), gr.update(visible=False),
|
| 184 |
+
gr.update(value=final_message, visible=True), prompt, gr.update(visible=False))
|
| 185 |
+
|
| 186 |
+
def rerun_ui(self, df, new_prompt, quantization_type, model_selection):
|
| 187 |
+
"""UI wrapper for rerun with progress updates."""
|
| 188 |
+
if df is None or not new_prompt.strip():
|
| 189 |
+
return (df, None, None, None,
|
| 190 |
+
gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
|
| 191 |
+
gr.update(visible=False), gr.update(visible=True), "β οΈ Please provide a valid prompt", "")
|
| 192 |
+
|
| 193 |
+
# Extract actual model name from the dropdown choice
|
| 194 |
+
actual_model_name = self.extract_model_name_from_choice(model_selection)
|
| 195 |
+
|
| 196 |
+
# Hide all sections and show only processing, clear model output display
|
| 197 |
+
yield (df, None, None, None,
|
| 198 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
|
| 199 |
+
gr.update(visible=False), gr.update(visible=True), "π Initializing reprocessing...", "Select a row from the table to see model output...")
|
| 200 |
+
|
| 201 |
+
# Process with new prompt
|
| 202 |
+
updated_df, accuracy_table_data, cm_plot, cm_values, section4_vis, progress_vis, final_message = self.inference_engine.rerun_with_new_prompt(
|
| 203 |
+
df, new_prompt, quantization_type, actual_model_name, gr.Progress()
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Show prompt editing and results sections again, show Generate Metrics button, hide progress, and clear model output display
|
| 207 |
+
yield (updated_df, accuracy_table_data, cm_plot, cm_values,
|
| 208 |
+
gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), section4_vis,
|
| 209 |
+
gr.update(visible=True), gr.update(visible=False), final_message, "Select a row from the table to see updated model output...")
|
| 210 |
+
|
| 211 |
+
    def create_interface(self):
        """Create and return the Gradio interface.

        Builds the full PROMPT_PILOT layout: model/quantization selection,
        folder + prompt input (section 1), prompt editing + rerun
        (section 2), results table with image preview (section 3), and
        metrics / confusion matrix (section 4), then wires all event
        handlers. Returns the `gr.Blocks` demo object (not yet launched).
        """
        # CSS from original app.py
        css = """
        .progress {
            margin: 15px 0;
            padding: 20px;
            border-radius: 12px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            border: none;
            color: white;
            font-weight: 600;
            font-size: 16px;
            text-align: center;
            box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
            animation: progressPulse 2s ease-in-out infinite alternate;
        }

        @keyframes progressPulse {
            0% {
                transform: scale(1);
                box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
            }
            100% {
                transform: scale(1.02);
                box-shadow: 0 6px 20px rgba(102, 126, 234, 0.4);
            }
        }

        .processing {
            background: linear-gradient(45deg, #f0f9ff, #e3f2fd);
            border: 2px solid #1976d2;
            border-radius: 10px;
            padding: 20px;
            text-align: center;
            margin: 10px 0;
        }

        .gr-button.processing {
            background-color: #ffa726 !important;
            color: white !important;
            pointer-events: none;
        }

        /* Stop button styling */
        .stop-button {
            background: linear-gradient(135deg, #ff4757 0%, #c44569 100%) !important;
            border: none !important;
            color: white !important;
            font-weight: 700 !important;
            font-size: 16px !important;
            box-shadow: 0 4px 15px rgba(255, 71, 87, 0.4) !important;
            transition: all 0.3s ease !important;
        }

        .stop-button:hover {
            transform: translateY(-2px) !important;
            box-shadow: 0 8px 25px rgba(255, 71, 87, 0.6) !important;
            background: linear-gradient(135deg, #ff3742 0%, #b83754 100%) !important;
        }

        .stop-status {
            color: #ff4757;
            font-weight: 600;
            background: rgba(255, 71, 87, 0.1);
            padding: 10px;
            border-radius: 8px;
            border-left: 4px solid #ff4757;
            margin: 10px 0;
        }

        /* Enhanced button styling */
        .gr-button {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            border: none;
            border-radius: 8px;
            color: white;
            font-weight: 600;
            transition: all 0.3s ease;
        }

        .gr-button:hover {
            transform: translateY(-2px);
            box-shadow: 0 8px 25px rgba(102, 126, 234, 0.4);
        }
        """

        with gr.Blocks(theme="origin", css=css) as demo:
            # Page header and binary-output caveat.
            gr.Markdown("""
            <h1 style='text-align:center; color:#1976d2; font-size:2.5em; font-weight:bold; margin-bottom:40px!important;'>PROMPT_PILOT</h1>
            <p style='text-align:center; color:#666; font-size:1.1em; margin-bottom:30px;'>
                🤖 AI-powered analysis with different vision models
            </p>
            <h2 style='text-align:center; color:#666; font-size:1.1em; margin-bottom:30px;'>
                Note: Currently Accuracy only works properly in case of binary output. For other cases kindly download the csv and calculate the accuracy separately.
            </h2>
            """, elem_id="main-title")

            # Model and Quantization selection dropdowns at the top
            model_choices = self.get_model_choices_with_info()
            default_choice = f"{self.config_manager.get_default_model()} (INTERNVL)"

            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    value=default_choice,
                    label="🤖 Model Selection",
                    info="Select model: InternVL (vision+text), Qwen (text-only)",
                    elem_id="model-dropdown"
                )
                quantization_dropdown = gr.Dropdown(
                    choices=["quantized(8bit)", "non-quantized(fp16)"],
                    value="non-quantized(fp16)",
                    label="🔧 Model Quantization",
                    info="Select quantization type: quantized (8bit) uses less memory, non-quantized (fp16) for better quality",
                    elem_id="quantization-dropdown"
                )

            # Model status indicator
            with gr.Row():
                model_status = gr.Markdown(
                    value=self.get_current_model_status(),
                    label="Model Status",
                    elem_classes=["model-status"]
                )

            # Stop button row
            with gr.Row():
                stop_btn = gr.Button("🛑 STOP PROCESSING", variant="stop", size="lg", elem_classes=["stop-button"])
                stop_status = gr.Markdown("", elem_classes=["stop-status"], visible=False)

            # Section 1: folder upload, prompt entry and submit button.
            with gr.Row(visible=True) as section1_row:
                with gr.Column():
                    folder_input = gr.File(
                        label="Upload Folder",
                        file_count="directory",
                        type="filepath"
                    )
                with gr.Column():
                    prompt_input = gr.Textbox(
                        label="Enter your prompt here",
                        placeholder="Type your prompt...",
                        lines=3
                    )
                with gr.Column():
                    submit_btn = gr.Button("Proceed", variant="primary")

            # Progress indicator for section 1
            with gr.Row(visible=True) as section1_progress_row:
                section1_progress_message = gr.Markdown("", elem_classes=["progress"], visible=False)

            # Section 2: Edit Prompt and Rerun Controls (separate section)
            with gr.Row(visible=False) as section2_prompt_row:
                with gr.Column():
                    with gr.Row():
                        prompt_input_section2 = gr.Textbox(
                            label="Edit Prompt",
                            placeholder="Modify your prompt here...",
                            lines=2,
                            scale=4
                        )
                        rerun_btn = gr.Button("🔄 Rerun", variant="secondary", size="lg", scale=1)

            # Section 3: Results Display
            with gr.Row(visible=False) as section3_results_row:
                error_message = gr.Textbox(label="Error Message", visible=False)
                with gr.Column(scale=1):
                    image_preview = gr.Image(label="Selected Image", height=270, width=480)
                    model_output_display = gr.Textbox(
                        label="Model Output for Selected Image",
                        placeholder="Select a row from the table to see model output...",
                        interactive=False,
                        lines=3
                    )
                with gr.Column(scale=2):
                    with gr.Row():
                        gr.HTML("")  # Empty space to push button to right
                        download_results_btn = gr.Button("📥 CSV", size="sm", scale=1)
                        results_csv_output = gr.File(label="", visible=True, scale=1, show_label=False)
                    results_table = gr.Dataframe(
                        headers=["S.No", "Image Name", "Ground Truth", "Binary Output", "Model Output"],
                        label="Results",
                        interactive=True,  # Make it editable for ground truth input
                        col_count=(5, "fixed")
                    )

            # Generate Metrics button
            with gr.Row(visible=False) as section3_submit_row:
                with gr.Column():
                    submit_results_btn = gr.Button("Generate Metrics", variant="primary", size="lg")

            # Progress indicator row
            with gr.Row(visible=False) as progress_row:
                progress_message = gr.Markdown("", elem_classes=["progress"])

            # Section 4: Metrics and confusion matrix
            with gr.Row(visible=False) as section4_metrics_row:
                with gr.Column(scale=2):
                    confusion_matrix_plot = gr.Image(
                        label="Confusion Matrix"
                    )
                with gr.Column(scale=2):
                    accuracy_table = gr.Dataframe(
                        label="Performance Metrics",
                        interactive=False
                    )
                    confusion_matrix_table = gr.Dataframe(
                        label="Confusion Matrix Table",
                        interactive=False
                    )

            # State to store folder path (kept so row-select can resolve
            # image files after the upload component is cleared).
            folder_path_state = gr.State()
            folder_input.change(
                fn=lambda x: x,
                inputs=[folder_input],
                outputs=[folder_path_state]
            )

            # Event handlers
            # NOTE(review): results_table appears twice in this outputs list
            # (positions 5 and 7); process_input_ui yields different values
            # for those slots — confirm the intended one wins in Gradio.
            submit_btn.click(
                fn=self.process_input_ui,
                inputs=[folder_input, prompt_input, quantization_dropdown, model_dropdown],
                outputs=[section1_row, section2_prompt_row, section3_results_row, error_message, results_table, image_preview, results_table, section3_submit_row, section4_metrics_row, section1_progress_message, prompt_input_section2, stop_status]
            )

            # Row selection drives the image preview + per-image model output.
            results_table.select(
                fn=self.update_image_preview,
                inputs=[results_table, folder_path_state],
                outputs=[image_preview, model_output_display]
            )

            submit_results_btn.click(
                fn=self.submit_and_show_metrics,
                inputs=[results_table],
                outputs=[results_table, results_table, accuracy_table, confusion_matrix_plot, confusion_matrix_table, section4_metrics_row, progress_row, progress_message]
            )

            download_results_btn.click(
                fn=self.download_results_csv,
                inputs=[results_table],
                outputs=[results_csv_output]
            )

            rerun_btn.click(
                fn=self.rerun_ui,
                inputs=[results_table, prompt_input_section2, quantization_dropdown, model_dropdown],
                outputs=[results_table, accuracy_table, confusion_matrix_plot, confusion_matrix_table,
                         section1_row, section2_prompt_row, section3_results_row, section4_metrics_row, section3_submit_row, progress_row, progress_message, model_output_display]
            )

            # Model change handler to update status
            model_dropdown.change(
                fn=self.on_model_change,
                inputs=[model_dropdown, quantization_dropdown],
                outputs=[model_status]
            )

            quantization_dropdown.change(
                fn=self.on_model_change,
                inputs=[model_dropdown, quantization_dropdown],
                outputs=[model_status]
            )

            # Stop button click handler
            # NOTE(review): stop_status is listed twice in outputs — looks
            # like handle_stop_button returns two values for the same
            # component; verify against that handler's return signature.
            stop_btn.click(
                fn=self.handle_stop_button,
                inputs=[],
                outputs=[stop_status, stop_status]
            )

        return demo
|
| 483 |
+
|
| 484 |
+
def launch(self, **kwargs):
|
| 485 |
+
"""Launch the Gradio application."""
|
| 486 |
+
demo = self.create_interface()
|
| 487 |
+
return demo.launch(**kwargs)
|
requirements.txt
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NOTE: duplicate requirements removed — pip aborts with
# "Double requirement given" when the same package is pinned twice
# (torch, torchvision, Pillow, and transformers were each listed twice).
numpy
pillow>=10.2.0
Requests
torch>=2.2.0
torchvision>=0.17.0
decord
# bleeding-edge transformers needed for the newest VLM support
# (supersedes the duplicate transformers>=4.37.2 pin)
git+https://github.com/huggingface/transformers.git
accelerate>=0.27.2
einops
timm
sentencepiece
gradio>=4.19.2
bitsandbytes>=0.42.0
pandas>=1.5.0
matplotlib>=3.5.0
seaborn>=0.11.0
scikit-learn>=1.0.0
pyyaml>=6.0.0
spaces
boto3
|