Spaces:
Sleeping
Sleeping
First commit - Individual test scripts present
Browse files- .dockerignore +0 -0
- .gitattributes +45 -0
- .gitignore +129 -0
- Dockerfile +75 -0
- README.md +226 -2
- app.py +396 -0
- config.py +327 -0
- docker-compose.yml +0 -0
- requirements.txt +24 -0
- src/__init__.py +10 -0
- src/models/__init__.py +35 -0
- src/models/caption_model.py +490 -0
- src/models/style_model.py +361 -0
- src/utils/__init__.py +54 -0
- src/utils/analytics.py +373 -0
- src/utils/cache_manager.py +403 -0
- src/utils/image_processor.py +373 -0
.dockerignore
ADDED
|
File without changes
|
.gitattributes
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================================
|
| 2 |
+
# GIT LFS CONFIGURATION FOR HUGGINGFACE SPACES
|
| 3 |
+
# ============================================================================
|
| 4 |
+
|
| 5 |
+
# Track large model files with Git LFS
|
| 6 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
|
| 14 |
+
# Track large media files
|
| 15 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.avi filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.mov filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
|
| 21 |
+
# Track large data files
|
| 22 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
|
| 25 |
+
# Python files as text
|
| 26 |
+
*.py text eol=lf
|
| 27 |
+
*.md text eol=lf
|
| 28 |
+
*.txt text eol=lf
|
| 29 |
+
*.json text eol=lf
|
| 30 |
+
*.yaml text eol=lf
|
| 31 |
+
*.yml text eol=lf
|
| 32 |
+
|
| 33 |
+
# Configuration files
|
| 34 |
+
*.toml text eol=lf
|
| 35 |
+
*.ini text eol=lf
|
| 36 |
+
*.cfg text eol=lf
|
| 37 |
+
|
| 38 |
+
# Docker files
|
| 39 |
+
Dockerfile text eol=lf
|
| 40 |
+
*.dockerfile text eol=lf
|
| 41 |
+
docker-compose.yml text eol=lf
|
| 42 |
+
|
| 43 |
+
# Shell scripts
|
| 44 |
+
*.sh text eol=lf
|
| 45 |
+
*.bash text eol=lf
|
.gitignore
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================================
|
| 2 |
+
# AI IMAGE CAPTION GENERATOR - GIT IGNORE
|
| 3 |
+
# ============================================================================
|
| 4 |
+
|
| 5 |
+
# Environment Variables
|
| 6 |
+
.env
|
| 7 |
+
.env.local
|
| 8 |
+
.env.*.local
|
| 9 |
+
|
| 10 |
+
# Python
|
| 11 |
+
__pycache__/
|
| 12 |
+
*.py[cod]
|
| 13 |
+
*$py.class
|
| 14 |
+
*.so
|
| 15 |
+
.Python
|
| 16 |
+
build/
|
| 17 |
+
develop-eggs/
|
| 18 |
+
dist/
|
| 19 |
+
downloads/
|
| 20 |
+
eggs/
|
| 21 |
+
.eggs/
|
| 22 |
+
lib/
|
| 23 |
+
lib64/
|
| 24 |
+
parts/
|
| 25 |
+
sdist/
|
| 26 |
+
var/
|
| 27 |
+
wheels/
|
| 28 |
+
pip-wheel-metadata/
|
| 29 |
+
share/python-wheels/
|
| 30 |
+
*.egg-info/
|
| 31 |
+
.installed.cfg
|
| 32 |
+
*.egg
|
| 33 |
+
MANIFEST
|
| 34 |
+
|
| 35 |
+
# Virtual Environment
|
| 36 |
+
venv/
|
| 37 |
+
env/
|
| 38 |
+
ENV/
|
| 39 |
+
env.bak/
|
| 40 |
+
venv.bak/
|
| 41 |
+
|
| 42 |
+
# IDEs
|
| 43 |
+
.vscode/
|
| 44 |
+
.idea/
|
| 45 |
+
*.swp
|
| 46 |
+
*.swo
|
| 47 |
+
*~
|
| 48 |
+
.DS_Store
|
| 49 |
+
|
| 50 |
+
# Jupyter Notebooks
|
| 51 |
+
.ipynb_checkpoints/
|
| 52 |
+
*.ipynb
|
| 53 |
+
|
| 54 |
+
# Model Cache & Downloads
|
| 55 |
+
# Ignore the contents of cache/ rather than the directory itself: git cannot
# re-include a file whose parent directory is excluded, so a plain `cache/`
# rule would make the `!cache/.gitkeep` exception below ineffective.
cache/*
|
| 56 |
+
*.pth
|
| 57 |
+
*.bin
|
| 58 |
+
*.safetensors
|
| 59 |
+
*.onnx
|
| 60 |
+
downloads/
|
| 61 |
+
|
| 62 |
+
# Analytics & Logs
|
| 63 |
+
*.log
|
| 64 |
+
logs/
|
| 65 |
+
analytics.json
|
| 66 |
+
cache/analytics.json
|
| 67 |
+
test_analytics.json
|
| 68 |
+
|
| 69 |
+
# Test Coverage
|
| 70 |
+
.coverage
|
| 71 |
+
.pytest_cache/
|
| 72 |
+
htmlcov/
|
| 73 |
+
.tox/
|
| 74 |
+
.nox/
|
| 75 |
+
coverage.xml
|
| 76 |
+
*.cover
|
| 77 |
+
.hypothesis/
|
| 78 |
+
|
| 79 |
+
# Gradio
|
| 80 |
+
gradio_cached_examples/
|
| 81 |
+
flagged/
|
| 82 |
+
|
| 83 |
+
# Docker
|
| 84 |
+
*.pid
|
| 85 |
+
*.seed
|
| 86 |
+
*.pid.lock
|
| 87 |
+
|
| 88 |
+
# Temporary Files
|
| 89 |
+
*.tmp
|
| 90 |
+
*.temp
|
| 91 |
+
.tmp/
|
| 92 |
+
temp/
|
| 93 |
+
|
| 94 |
+
# OS Files
|
| 95 |
+
Thumbs.db
|
| 96 |
+
.DS_Store
|
| 97 |
+
desktop.ini
|
| 98 |
+
|
| 99 |
+
# Backup Files
|
| 100 |
+
*.bak
|
| 101 |
+
*.backup
|
| 102 |
+
*~
|
| 103 |
+
|
| 104 |
+
# Large Files (use Git LFS instead)
|
| 105 |
+
*.mp4
|
| 106 |
+
*.avi
|
| 107 |
+
*.mov
|
| 108 |
+
*.zip
|
| 109 |
+
*.tar.gz
|
| 110 |
+
|
| 111 |
+
# HuggingFace Spaces specific
|
| 112 |
+
spaces/
|
| 113 |
+
|
| 114 |
+
# PyTorch
|
| 115 |
+
*.pt
|
| 116 |
+
lightning_logs/
|
| 117 |
+
|
| 118 |
+
# Profiling
|
| 119 |
+
*.prof
|
| 120 |
+
|
| 121 |
+
# Documentation builds
|
| 122 |
+
docs/_build/
|
| 123 |
+
docs/.doctrees/
|
| 124 |
+
|
| 125 |
+
# ============================================================================
|
| 126 |
+
# KEEP THESE DIRECTORIES (create .gitkeep files)
|
| 127 |
+
# ============================================================================
|
| 128 |
+
!cache/.gitkeep
|
| 129 |
+
!static/images/examples/.gitkeep
|
Dockerfile
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================================
|
| 2 |
+
# AI IMAGE CAPTION GENERATOR - DOCKERFILE
|
| 3 |
+
# ============================================================================
|
| 4 |
+
# Multi-stage build for optimized production image
|
| 5 |
+
# Compatible with HuggingFace Spaces and local deployment
|
| 6 |
+
# ============================================================================
|
| 7 |
+
|
| 8 |
+
FROM python:3.10-slim as base
|
| 9 |
+
|
| 10 |
+
# Set environment variables
|
| 11 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 12 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 13 |
+
PIP_NO_CACHE_DIR=1 \
|
| 14 |
+
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
| 15 |
+
DEBIAN_FRONTEND=noninteractive
|
| 16 |
+
|
| 17 |
+
# Set working directory
|
| 18 |
+
WORKDIR /app
|
| 19 |
+
|
| 20 |
+
# ============================================================================
|
| 21 |
+
# DEPENDENCIES STAGE
|
| 22 |
+
# ============================================================================
|
| 23 |
+
|
| 24 |
+
FROM base as dependencies
|
| 25 |
+
|
| 26 |
+
# Install system dependencies
|
| 27 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 28 |
+
build-essential \
|
| 29 |
+
git \
|
| 30 |
+
curl \
|
| 31 |
+
libgl1-mesa-glx \
|
| 32 |
+
libglib2.0-0 \
|
| 33 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 34 |
+
|
| 35 |
+
# Copy requirements
|
| 36 |
+
COPY requirements.txt .
|
| 37 |
+
|
| 38 |
+
# Install Python dependencies
|
| 39 |
+
RUN pip install --upgrade pip && \
|
| 40 |
+
pip install -r requirements.txt
|
| 41 |
+
|
| 42 |
+
# ============================================================================
|
| 43 |
+
# RUNTIME STAGE
|
| 44 |
+
# ============================================================================
|
| 45 |
+
|
| 46 |
+
FROM base as runtime
|
| 47 |
+
|
| 48 |
+
# Install runtime dependencies only
|
| 49 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 50 |
+
libgl1-mesa-glx \
|
| 51 |
+
libglib2.0-0 \
|
| 52 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 53 |
+
|
| 54 |
+
# Copy Python dependencies from builder
|
| 55 |
+
COPY --from=dependencies /usr/local/lib/python3.10/site-packages /usr/local/lib/python3.10/site-packages
|
| 56 |
+
COPY --from=dependencies /usr/local/bin /usr/local/bin
|
| 57 |
+
|
| 58 |
+
# Copy application code
|
| 59 |
+
COPY . .
|
| 60 |
+
|
| 61 |
+
# Create necessary directories
|
| 62 |
+
RUN mkdir -p cache/models cache/analytics static/images/examples
|
| 63 |
+
|
| 64 |
+
# Set permissions
|
| 65 |
+
RUN chmod -R 755 /app
|
| 66 |
+
|
| 67 |
+
# Expose Gradio default port
|
| 68 |
+
EXPOSE 7860
|
| 69 |
+
|
| 70 |
+
# Health check
|
| 71 |
+
# curl is only installed in the dependencies stage, not in this runtime stage,
# so probe the Gradio port with the Python interpreter that is always present.
# urlopen raises (non-zero exit) on connection failure or an HTTP error status.
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/', timeout=5)" || exit 1
|
| 73 |
+
|
| 74 |
+
# Run the application
|
| 75 |
+
CMD ["python", "app.py"]
|
README.md
CHANGED
|
@@ -1,2 +1,226 @@
|
|
| 1 |
-
#
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🖼️ AI Image Caption Generator
|
| 2 |
+
|
| 3 |
+
[Python 3.10+](https://www.python.org/downloads/)
|
| 4 |
+
[PyTorch](https://pytorch.org/)
|
| 5 |
+
[License: MIT](https://opensource.org/licenses/MIT)
|
| 6 |
+
[Hugging Face Spaces demo](https://huggingface.co/spaces/ChinmayM06/ai-image-caption-generator)
|
| 7 |
+
|
| 8 |
+
> Generate AI-powered image captions with multiple style options—completely free, no API costs.
|
| 9 |
+
|
| 10 |
+
A lightweight, GPU-accelerated image captioning tool using state-of-the-art vision-language models (BLIP & GIT) with style customization powered by Groq's free LLM API.
|
| 11 |
+
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
## ✨ Features
|
| 15 |
+
|
| 16 |
+
- 🎯 **Dual Model Support**: Both BLIP-base (fast) and GIT-large (high quality) run simultaneously
|
| 17 |
+
- 🎨 **5 Caption Styles**: None, Creative, Social Media, Professional, Technical
|
| 18 |
+
- ⚡ **GPU Accelerated**: Optimized for NVIDIA GPUs (works on CPU too)
|
| 19 |
+
- 💾 **Smart Caching**: LRU cache with configurable TTL for faster repeated requests
|
| 20 |
+
- 📊 **Analytics Tracking**: Built-in usage statistics and performance metrics
|
| 21 |
+
- 🖼️ **Image Processing**: Automatic validation, resizing, and format conversion
|
| 22 |
+
- 🔄 **Fallback Mechanisms**: Graceful degradation when API is unavailable
|
| 23 |
+
- 💰 **100% Free**: No OpenAI credits, no hidden costs
|
| 24 |
+
- 🔒 **Privacy First**: Local inference option available
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## 🚀 Live Demo
|
| 29 |
+
|
| 30 |
+
Try it out without any installation:
|
| 31 |
+
|
| 32 |
+
**[🎮 Launch Live Demo →](https://huggingface.co/spaces/ChinmayM06/ai-image-caption-generator)**
|
| 33 |
+
|
| 34 |
+
*Add your Hugging Face Spaces URL above after deployment*
|
| 35 |
+
|
| 36 |
+
---
|
| 37 |
+
|
| 38 |
+
## 🛠️ Tech Stack
|
| 39 |
+
|
| 40 |
+
| Component | Technology |
|
| 41 |
+
|-----------|-----------|
|
| 42 |
+
| **Vision Models** | BLIP-base, GIT-large (Hugging Face) |
|
| 43 |
+
| **Style LLM** | Groq API (free tier) |
|
| 44 |
+
| **Framework** | PyTorch 2.1.0 + CUDA 11.8 |
|
| 45 |
+
| **Interface** | Gradio 4.8.0 |
|
| 46 |
+
| **Deployment** | Hugging Face Spaces (T4 GPU) |
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## 📦 Quick Start
|
| 51 |
+
|
| 52 |
+
### Prerequisites
|
| 53 |
+
|
| 54 |
+
- Python 3.10+
|
| 55 |
+
- NVIDIA GPU with 4GB+ VRAM (recommended) or CPU
|
| 56 |
+
- CUDA 11.8 (for GPU acceleration)
|
| 57 |
+
|
| 58 |
+
### Installation
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
# Clone repository
|
| 62 |
+
git clone https://github.com/ChinmayM06/ai-image-caption-generator.git
|
| 63 |
+
cd ai-image-caption-generator
|
| 64 |
+
|
| 65 |
+
# Create virtual environment
|
| 66 |
+
python -m venv venv
|
| 67 |
+
source venv/bin/activate # Windows: venv\Scripts\activate
|
| 68 |
+
|
| 69 |
+
# Install PyTorch with CUDA support
|
| 70 |
+
pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118
|
| 71 |
+
|
| 72 |
+
# Install dependencies
|
| 73 |
+
pip install -r requirements.txt
|
| 74 |
+
|
| 75 |
+
# Set up environment variables (optional)
|
| 76 |
+
# Create a .env file in the project root with:
|
| 77 |
+
# GROQ_API_KEY=your_groq_api_key_here
|
| 78 |
+
# Get your free API key at https://console.groq.com
|
| 79 |
+
# Note: The app works without API key but styling features will use fallback templates
|
| 80 |
+
|
| 81 |
+
# Run the application
|
| 82 |
+
python app.py
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Access at `http://localhost:7860`
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
## 🎯 Usage
|
| 90 |
+
|
| 91 |
+
### Basic Usage
|
| 92 |
+
|
| 93 |
+
```python
|
| 94 |
+
from src.models import get_model_manager, get_style_model
|
| 95 |
+
from src.utils import get_image_processor
|
| 96 |
+
from PIL import Image
|
| 97 |
+
|
| 98 |
+
# Initialize components (singleton pattern)
|
| 99 |
+
model_manager = get_model_manager()
|
| 100 |
+
style_model = get_style_model()
|
| 101 |
+
image_processor = get_image_processor()
|
| 102 |
+
|
| 103 |
+
# Load models (BLIP and GIT)
|
| 104 |
+
blip_success, git_success = model_manager.load_all_models()
|
| 105 |
+
|
| 106 |
+
# Load and preprocess image
|
| 107 |
+
image = Image.open("your_image.jpg")
|
| 108 |
+
processed_img, metadata = image_processor.preprocess_image(image)
|
| 109 |
+
|
| 110 |
+
# Generate captions from both models
|
| 111 |
+
captions = model_manager.generate_captions(processed_img)
|
| 112 |
+
blip_caption = captions["blip"]
|
| 113 |
+
git_caption = captions["git"]
|
| 114 |
+
|
| 115 |
+
# Apply style (optional)
|
| 116 |
+
styled_blip = style_model.style_caption(blip_caption, style="Professional")
|
| 117 |
+
styled_git = style_model.style_caption(git_caption, style="Creative")
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
### Available Models
|
| 121 |
+
|
| 122 |
+
Both models run simultaneously to provide comparison:
|
| 123 |
+
- **BLIP-base**: Fast inference (~1-2s), good quality, efficient
|
| 124 |
+
- **GIT-large**: Slower (~3-4s), superior caption quality, more detailed
|
| 125 |
+
|
| 126 |
+
### Caption Styles
|
| 127 |
+
|
| 128 |
+
| Style | Use Case | Example |
|
| 129 |
+
|-------|----------|---------|
|
| 130 |
+
| **None** | Raw model output | "A dog sitting on grass" |
|
| 131 |
+
| **Creative** | Artistic, imaginative | "A joyful golden retriever basking in nature's embrace" |
|
| 132 |
+
| **Social Media** | Engaging, hashtag-ready | "Meet this good boy enjoying sunny vibes! 🐕☀️ #DogLife" |
|
| 133 |
+
| **Professional** | Business, formal | "Canine subject positioned in outdoor environment" |
|
| 134 |
+
| **Technical** | Detailed, analytical | "Golden retriever breed, seated posture, natural lighting, outdoor setting" |
|
| 135 |
+
|
| 136 |
+
---
|
| 137 |
+
|
| 138 |
+
## 🐳 Docker Deployment
|
| 139 |
+
|
| 140 |
+
```bash
|
| 141 |
+
# Build image
|
| 142 |
+
docker build -t caption-generator .
|
| 143 |
+
|
| 144 |
+
# Run container (with GPU)
|
| 145 |
+
docker run --gpus all -p 7860:7860 caption-generator
|
| 146 |
+
|
| 147 |
+
# Run container (CPU only)
|
| 148 |
+
docker run -p 7860:7860 -e DEVICE=cpu caption-generator
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
---
|
| 152 |
+
|
| 153 |
+
## ⚙️ Configuration
|
| 154 |
+
|
| 155 |
+
### Environment Variables
|
| 156 |
+
|
| 157 |
+
Create a `.env` file in the project root (optional):
|
| 158 |
+
|
| 159 |
+
```bash
|
| 160 |
+
# Groq API Key (optional: enables LLM-based styling; without it, fallback templates are used)
|
| 161 |
+
GROQ_API_KEY=your_groq_api_key_here
|
| 162 |
+
|
| 163 |
+
# Hardware Configuration (optional, defaults to 'cuda' if available)
|
| 164 |
+
DEVICE=cuda # or 'cpu'
|
| 165 |
+
|
| 166 |
+
# Logging Level (optional)
|
| 167 |
+
LOG_LEVEL=INFO # DEBUG, INFO, WARNING, ERROR
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
---
|
| 171 |
+
|
| 172 |
+
## 🎓 Why This Project?
|
| 173 |
+
|
| 174 |
+
Built as a learning project to explore:
|
| 175 |
+
- **GenAI Fundamentals**: Vision-language models, prompt engineering
|
| 176 |
+
- **Practical ML Skills**: GPU optimization, model deployment, API integration
|
| 177 |
+
- **Cost Optimization**: Demonstrating production-quality AI without expensive APIs
|
| 178 |
+
- **Software Architecture**: Caching, analytics, error handling, thread safety
|
| 179 |
+
|
| 180 |
+
Perfect for understanding how modern image captioning works under the hood while keeping infrastructure costs at zero.
|
| 181 |
+
|
| 182 |
+
---
|
| 183 |
+
|
| 184 |
+
## 🤝 Contributing
|
| 185 |
+
|
| 186 |
+
Contributions welcome! Feel free to:
|
| 187 |
+
- Report bugs
|
| 188 |
+
- Suggest features
|
| 189 |
+
- Submit pull requests
|
| 190 |
+
- Improve documentation
|
| 191 |
+
- Add new caption styles
|
| 192 |
+
- Optimize performance
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
---
|
| 196 |
+
|
| 197 |
+
## 📝 License
|
| 198 |
+
|
| 199 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 200 |
+
|
| 201 |
+
---
|
| 202 |
+
|
| 203 |
+
## 🙏 Acknowledgments
|
| 204 |
+
|
| 205 |
+
- [Salesforce BLIP](https://github.com/salesforce/BLIP) - Image captioning model
|
| 206 |
+
- [Microsoft GIT](https://github.com/microsoft/GenerativeImage2Text) - High-quality captions
|
| 207 |
+
- [Groq](https://groq.com) - Free LLM inference API
|
| 208 |
+
- [Hugging Face](https://huggingface.co) - Model hosting & deployment
|
| 209 |
+
|
| 210 |
+
---
|
| 211 |
+
|
| 212 |
+
## 📬 Contact
|
| 213 |
+
|
| 214 |
+
**Chinmay M** - [@ChinmayM06](https://github.com/ChinmayM06)
|
| 215 |
+
|
| 216 |
+
Project Link: [https://github.com/ChinmayM06/ai-image-caption-generator](https://github.com/ChinmayM06/ai-image-caption-generator)
|
| 217 |
+
|
| 218 |
+
---
|
| 219 |
+
|
| 220 |
+
<div align="center">
|
| 221 |
+
|
| 222 |
+
**[⭐ Star this repo](https://github.com/ChinmayM06/ai-image-caption-generator)** if you find it helpful!
|
| 223 |
+
|
| 224 |
+
Made with ❤️ and lots of ☕
|
| 225 |
+
|
| 226 |
+
</div>
|
app.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AI Image Caption Generator - Main Application
|
| 3 |
+
|
| 4 |
+
Gradio-based web interface for generating image captions using BLIP and GIT models
|
| 5 |
+
with customizable styling via Groq API.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import time
|
| 10 |
+
import numpy as np
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from typing import Tuple, Optional
|
| 14 |
+
|
| 15 |
+
# Import our modules
|
| 16 |
+
from config import ui_config, performance_config
|
| 17 |
+
from src.utils import (
|
| 18 |
+
get_image_processor,
|
| 19 |
+
get_caption_cache,
|
| 20 |
+
get_analytics_manager,
|
| 21 |
+
ImageProcessingError
|
| 22 |
+
)
|
| 23 |
+
from src.models import (
|
| 24 |
+
get_model_manager,
|
| 25 |
+
get_style_model,
|
| 26 |
+
CaptionModelError,
|
| 27 |
+
StyleModelError
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class CaptionGeneratorApp:
|
| 32 |
+
"""
|
| 33 |
+
Main application class for the caption generator
|
| 34 |
+
|
| 35 |
+
Manages the Gradio interface and coordinates all components
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self):
    """
    Wire up the application's collaborators and load the caption models.

    Obtains the image processor, model manager, style model, caption cache
    and analytics manager from the project's factory helpers, asks the model
    manager to load its models, and prints a start-up report to stdout
    (including whether the style model reports its API as available).
    """
    banner = "=" * 60

    print(banner)
    print("🚀 INITIALIZING AI IMAGE CAPTION GENERATOR")
    print(banner)

    # Collaborators are created in a fixed order before any model is loaded.
    self.image_processor = get_image_processor()
    self.model_manager = get_model_manager()
    self.style_model = get_style_model()
    self.cache = get_caption_cache()
    self.analytics = get_analytics_manager()

    print("\n✓ Components initialized")

    # Model loading reports one success flag per model.
    print("\n📦 Loading AI models (this may take a few minutes on first run)...")
    blip_ok, git_ok = self.model_manager.load_all_models()

    if blip_ok and git_ok:
        print("\n✓ All models loaded successfully")
    else:
        # Start-up continues even if a model failed; the failure is only reported.
        print("\n⚠️ Warning: Some models failed to load")
        for label, loaded in (("BLIP", blip_ok), ("GIT", git_ok)):
            mark = '✓' if loaded else '✗'
            print(f" {label}: {mark}")

    # Report which styling path will be used.
    if not self.style_model.is_api_available():
        print("⚠️ Groq API not available - using fallback styling")
    else:
        print("✓ Groq API connected")

    print("\n" + banner)
    print("✅ INITIALIZATION COMPLETE")
    print(banner + "\n")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def generate_captions(
    self,
    image,  # File path (str), PIL Image, or NumPy array
    style: str,
    progress=gr.Progress()
) -> Tuple[str, str, str]:
    """
    Generate styled BLIP and GIT captions for a single image.

    The image is converted to a PIL image, validated and preprocessed, then
    hashed so styled captions can be looked up in / written to the cache per
    (image hash, model, style). The models are only invoked when at least one
    of the two captions is missing from the cache.

    Args:
        image: Image path (str), PIL Image, or NumPy array. ``None`` yields
            an error result.
        style: Style to apply
        progress: Gradio progress tracker

    Returns:
        Tuple[str, str, str]: (blip_caption, git_caption, stats_text).
        Failures are reported through these strings; the method itself
        does not raise.
    """
    start_time = time.time()
    model_names = ("blip", "git")
    missing_caption = "Error generating caption"

    try:
        # Step 1: Validate and preprocess image
        progress(0.1, desc="Validating image...")

        if image is None:
            return (
                "❌ Error: No image provided",
                "❌ Error: No image provided",
                "⚠️ Please upload an image"
            )

        # Convert to PIL Image from various formats
        try:
            if isinstance(image, str):
                # File path
                pil_image = Image.open(image)
            elif isinstance(image, Image.Image):
                # Already PIL Image
                pil_image = image
            elif hasattr(image, 'shape'):
                # Array-like: only NumPy arrays are supported (np comes from
                # the module-level import).
                if isinstance(image, np.ndarray):
                    pil_image = Image.fromarray(image.astype('uint8'))
                else:
                    raise ValueError("Unsupported array type")
            else:
                return (
                    f"❌ Error: Unsupported image type: {type(image)}",
                    f"❌ Error: Unsupported image type: {type(image)}",
                    "⚠️ Image format not supported"
                )
        except Exception as e:
            # Any loading/conversion failure is surfaced to the UI as text.
            return (
                f"❌ Error: Cannot load image - {str(e)}",
                f"❌ Error: Cannot load image - {str(e)}",
                "⚠️ Image loading failed"
            )

        # Validate image
        is_valid, error_msg = self.image_processor.validate_image(pil_image)
        if not is_valid:
            return (
                f"❌ Error: {error_msg}",
                f"❌ Error: {error_msg}",
                "⚠️ Image validation failed"
            )

        # Preprocess image (the returned metadata is not needed here)
        progress(0.2, desc="Processing image...")
        processed_img, _ = self.image_processor.preprocess_image(pil_image)

        # Generate image hash for caching
        image_hash = self.image_processor.generate_image_hash(processed_img)

        # Step 2: Check cache
        progress(0.3, desc="Checking cache...")

        cached = {
            name: self.cache.get_caption(image_hash, name, style)
            for name in model_names
        }

        # Step 3: Generate captions if not cached
        raw_captions = {}

        if any(entry is None for entry in cached.values()):
            progress(0.4, desc="Generating captions...")
            raw_captions = self.model_manager.generate_captions(processed_img)

        # Step 4: Apply styling
        progress(0.6, desc=f"Applying {style} style...")

        styled_captions = {}

        for name in model_names:
            hit = cached[name]
            # A hit is "anything but None" -- the same test that decided
            # whether the models had to run -- so a cached empty string is
            # honoured instead of falling through with no raw caption.
            if hit is not None:
                styled_captions[name] = hit
                continue

            raw = raw_captions.get(name)
            if raw is None:
                # The model produced nothing: report it, but never style or
                # cache the placeholder, otherwise the error text would be
                # served from the cache on every later request.
                styled_captions[name] = missing_caption
                continue

            styled = self.style_model.style_caption(raw, style)
            self.cache.set_caption(image_hash, name, style, styled)
            styled_captions[name] = styled

        # Step 5: Record analytics
        progress(0.9, desc="Finalizing...")

        processing_time = time.time() - start_time

        # Record for each model; the elapsed time is split evenly between them
        for name in model_names:
            self.analytics.record_caption_generation(
                name, style, processing_time / 2, True
            )

        # Get stats
        stats_text = self.analytics.get_display_stats()
        stats_text += f" | ⏱️ This generation: {processing_time:.2f}s"

        progress(1.0, desc="Complete!")

        return (
            styled_captions["blip"],
            styled_captions["git"],
            stats_text
        )

    except ImageProcessingError as e:
        error_msg = f"❌ Image Error: {str(e)}"
        return error_msg, error_msg, "⚠️ Image processing failed"

    except CaptionModelError as e:
        error_msg = f"❌ Model Error: {str(e)}"
        return error_msg, error_msg, "⚠️ Caption generation failed"

    except Exception as e:
        # Last-resort boundary for the UI callback: log, record, report.
        error_msg = f"❌ Unexpected Error: {str(e)}"
        print(f"Error in generate_captions: {e}")

        # Record error
        self.analytics.record_caption_generation("unknown", style, 0, False)

        return error_msg, error_msg, "⚠️ An error occurred"
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def create_interface(self) -> gr.Blocks:
    """
    Create Gradio interface

    Lays out an input column (image upload, style dropdown, generate
    button), an output column (one textbox per captioning model), a
    statistics banner, an optional examples gallery and a footer, then
    wires the generate button to ``self.generate_captions``.

    Returns:
        gr.Blocks: Configured Gradio interface
    """
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title=ui_config.TITLE,
        css=self._get_custom_css()
    ) as interface:

        # Header
        gr.Markdown(f"# {ui_config.TITLE}")
        gr.Markdown(ui_config.DESCRIPTION)

        with gr.Row():
            with gr.Column(scale=1):
                # Input section
                gr.Markdown("### 📤 Upload Image")

                image_input = gr.Image(
                    label="Upload your image",
                    type="pil",
                    height=ui_config.IMAGE_HEIGHT
                )

                style_dropdown = gr.Dropdown(
                    choices=self.style_model.get_available_styles(),
                    value="Professional",
                    label="🎨 Choose Caption Style",
                    info="Select how you want your caption to be styled"
                )

                generate_btn = gr.Button(
                    "✨ Generate Captions",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                # Output section
                gr.Markdown("### 📝 Generated Captions")

                with gr.Group():
                    gr.Markdown("**🤖 BLIP Caption**")
                    blip_output = gr.Textbox(
                        label="",
                        placeholder="BLIP caption will appear here...",
                        lines=3,
                        show_copy_button=True
                    )

                with gr.Group():
                    gr.Markdown("**🤖 GIT Caption**")
                    git_output = gr.Textbox(
                        label="",
                        placeholder="GIT caption will appear here...",
                        lines=3,
                        show_copy_button=True
                    )

        # Statistics section
        with gr.Row():
            stats_display = gr.Markdown(
                value=self.analytics.get_display_stats(),
                elem_id="stats-display"
            )

        # Examples section (if examples exist)
        examples_dir = ui_config.EXAMPLES_DIR
        # Path.glob() returns a generator, which cannot be sliced. Collect
        # the matches once (sorted so the same three files are always
        # picked) and slice the resulting list instead.
        example_images = (
            sorted(examples_dir.glob("*.jpg")) if examples_dir.exists() else []
        )
        if example_images:
            gr.Markdown("### 💡 Try These Examples")
            gr.Examples(
                examples=[str(p) for p in example_images[:3]],
                inputs=image_input,
                label=""
            )

        # Footer
        gr.Markdown(
            """
            ---
            <div style='text-align: center; color: #666; font-size: 0.9em;'>
            <p>🚀 Powered by BLIP, GIT, and Groq API | Built with ❤️ using Gradio</p>
            <p>⚡ Free and Open Source | 📊 All processing done securely</p>
            </div>
            """,
            elem_id="footer"
        )

        # Event handlers
        # The three outputs match the 3-tuple returned by generate_captions:
        # (BLIP caption, GIT caption, stats text).
        generate_btn.click(
            fn=self.generate_captions,
            inputs=[image_input, style_dropdown],
            outputs=[blip_output, git_output, stats_display],
            api_name="generate"
        )

    return interface
|
| 323 |
+
|
| 324 |
+
def _get_custom_css(self) -> str:
    """
    Get custom CSS for the interface

    The rules target the ``stats-display`` and ``footer`` element ids
    assigned in ``create_interface`` plus the ``.gr-button-primary`` class.

    Returns:
        str: CSS text handed to ``gr.Blocks(css=...)``
    """
    return """
    #stats-display {
        padding: 15px;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        text-align: center;
        font-weight: 500;
        margin: 20px 0;
    }

    #footer {
        margin-top: 30px;
    }

    .gr-button-primary {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
        border: none !important;
        font-weight: 600 !important;
    }

    .gr-button-primary:hover {
        transform: translateY(-2px);
        box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
        transition: all 0.3s ease;
    }
    """
|
| 353 |
+
|
| 354 |
+
def launch(
    self,
    share: bool = False,
    server_name: str = "0.0.0.0",
    server_port: int = 7860
):
    """
    Build the Gradio interface and start serving it

    Args:
        share: Create public URL
        server_name: Server host
        server_port: Server port
    """
    # Build the UI first, then gather the launch options in one place:
    # caller-supplied network settings plus the flags fixed in ui_config.
    app_interface = self.create_interface()

    launch_options = {
        "share": share,
        "server_name": server_name,
        "server_port": server_port,
        "show_api": ui_config.SHOW_API,
        "show_error": ui_config.SHOW_ERROR,
    }
    app_interface.launch(**launch_options)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def main():
    """
    Main entry point

    Creates the application and serves it on 0.0.0.0:7860. Ctrl+C is
    reported as a graceful shutdown; any other exception is printed and
    re-raised.
    """
    try:
        CaptionGeneratorApp().launch(
            share=False,  # Set to True to create public URL
            server_name="0.0.0.0",
            server_port=7860,
        )
    except KeyboardInterrupt:
        print("\n\n👋 Shutting down gracefully...")
    except Exception as exc:
        print(f"\n❌ Fatal error: {exc}")
        raise
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
config.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Centralized Configuration Module
|
| 3 |
+
|
| 4 |
+
This module contains all configuration settings for the AI Image Caption Generator.
|
| 5 |
+
Follows the single source of truth principle for easy maintenance and deployment.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Final

from dotenv import load_dotenv
|
| 13 |
+
|
| 14 |
+
# Load environment variables
|
| 15 |
+
load_dotenv()
|
| 16 |
+
|
| 17 |
+
# ============================================================================
|
| 18 |
+
# PROJECT PATHS
|
| 19 |
+
# ============================================================================
|
| 20 |
+
|
| 21 |
+
# Directory containing this config module; every other path derives from it.
PROJECT_ROOT: Final[Path] = Path(__file__).parent
CACHE_DIR: Final[Path] = PROJECT_ROOT / "cache"
MODEL_CACHE_DIR: Final[Path] = CACHE_DIR / "models"
ANALYTICS_FILE: Final[Path] = CACHE_DIR / "analytics.json"
STATIC_DIR: Final[Path] = PROJECT_ROOT / "static"

# Create directories if they don't exist
# NOTE: this runs at import time, so importing this module writes to disk.
for directory in [CACHE_DIR, MODEL_CACHE_DIR, STATIC_DIR]:
    directory.mkdir(parents=True, exist_ok=True)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ============================================================================
|
| 33 |
+
# MODEL CONFIGURATION
|
| 34 |
+
# ============================================================================
|
| 35 |
+
|
| 36 |
+
@dataclass(frozen=True)
class ModelConfig:
    """
    Configuration for caption generation models

    Immutable (frozen) settings: Hugging Face model identifiers, generation
    defaults (max length and beam count) for BLIP and GIT, the preferred
    device and the model cache directory.
    """

    # BLIP Model
    BLIP_MODEL_NAME: str = "Salesforce/blip-image-captioning-base"
    BLIP_MAX_LENGTH: int = 50
    BLIP_NUM_BEAMS: int = 3

    # GIT Model
    GIT_MODEL_NAME: str = "microsoft/git-large-coco"
    GIT_MAX_LENGTH: int = 50
    GIT_NUM_BEAMS: int = 3

    # Device Configuration
    DEVICE: str = "cuda"  # Will auto-fallback to CPU if CUDA unavailable

    # Memory Management
    # Defaults to the module-level MODEL_CACHE_DIR path (<cache>/models).
    MODEL_CACHE_DIR: Path = MODEL_CACHE_DIR
    LOW_MEMORY_MODE: bool = False  # Enable for systems with <8GB GPU memory
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ============================================================================
|
| 59 |
+
# IMAGE PROCESSING CONFIGURATION
|
| 60 |
+
# ============================================================================
|
| 61 |
+
|
| 62 |
+
@dataclass(frozen=True)
class ImageConfig:
    """
    Configuration for image validation and preprocessing

    Immutable (frozen) limits on upload size and dimensions, the accepted
    formats and extensions, and resize settings.
    """

    # Size Constraints
    MAX_FILE_SIZE_MB: int = 5
    # Evaluated once, while the class body runs, from the default above;
    # passing a different MAX_FILE_SIZE_MB to the constructor does NOT
    # update this value.
    MAX_FILE_SIZE_BYTES: int = MAX_FILE_SIZE_MB * 1024 * 1024
    MAX_DIMENSION: int = 512  # Max width/height for model input
    MIN_DIMENSION: int = 32  # Minimum acceptable dimension

    # Supported Formats
    ALLOWED_FORMATS: tuple = ("JPEG", "PNG", "WEBP", "JPG")
    ALLOWED_EXTENSIONS: tuple = (".jpg", ".jpeg", ".png", ".webp")

    # Processing
    RESIZE_QUALITY: int = 95  # JPEG quality after resize
    MAINTAIN_ASPECT_RATIO: bool = True
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ============================================================================
|
| 82 |
+
# GROQ API CONFIGURATION
|
| 83 |
+
# ============================================================================
|
| 84 |
+
|
| 85 |
+
@dataclass(frozen=True)
class GroqConfig:
    """
    Configuration for Groq API styling

    Immutable (frozen) settings for the Groq chat-completion calls: API
    credentials, model name, sampling parameters, timeout, retry policy and
    rate limit.
    """

    # API Settings
    # The default is read from the environment once, while this class body
    # runs (after load_dotenv() at module import). repr=False keeps the
    # secret out of the auto-generated __repr__, so printing or logging the
    # config object cannot leak the key.
    API_KEY: str = field(default=os.getenv("GROQ_API_KEY", ""), repr=False)
    MODEL_NAME: str = "llama-3.1-8b-instant"

    # Request Parameters
    MAX_TOKENS: int = 150
    TEMPERATURE: float = 0.7
    TOP_P: float = 0.9
    TIMEOUT_SECONDS: int = 10

    # Retry Logic
    MAX_RETRIES: int = 3
    RETRY_DELAY_SECONDS: float = 1.0

    # Rate Limiting
    REQUESTS_PER_MINUTE: int = 30
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ============================================================================
|
| 108 |
+
# STYLE CONFIGURATION
|
| 109 |
+
# ============================================================================
|
| 110 |
+
|
| 111 |
+
class StyleConfig:
    """
    Configuration for caption styling options

    A plain class (not a dataclass) holding class-level constants: the
    rewrite instruction for each style, the default style, and the fallback
    templates.
    """

    # Maps each style name to the rewrite instruction text for that style.
    STYLES: Final[Dict[str, str]] = {
        "None": "Keep the original caption without any modifications.",
        "Professional": "Rewrite this image caption in a professional, business-appropriate tone. Make it clear, formal, and suitable for corporate presentations or reports.",
        "Creative": "Transform this caption into a creative, artistic, and imaginative description. Use vivid language and engaging expressions.",
        "Social Media": "Rewrite this caption for social media platforms. Make it engaging, add relevant emojis, and make it shareable. Keep it under 280 characters.",
        "Technical": "Rewrite this caption with technical precision and detailed analysis. Focus on specific elements, composition, and visual characteristics."
    }

    # Must be a key of STYLES; validate_config() checks this.
    DEFAULT_STYLE: Final[str] = "Professional"

    # Fallback templates when API fails
    # Each template has a single "{caption}" placeholder; keys mirror STYLES.
    FALLBACK_TEMPLATES: Final[Dict[str, str]] = {
        "Professional": "Image Description: {caption}",
        "Creative": "✨ {caption} ✨",
        "Social Media": "📸 {caption} #AI #ImageCaption",
        "Technical": "Visual Analysis: {caption}",
        "None": "{caption}"
    }
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# ============================================================================
|
| 135 |
+
# CACHE CONFIGURATION
|
| 136 |
+
# ============================================================================
|
| 137 |
+
|
| 138 |
+
@dataclass(frozen=True)
class CacheConfig:
    """
    Configuration for caching system

    Immutable (frozen) settings: capacity, entry lifetime, an on/off switch
    for caption caching and the hash algorithm name used for cache keys.
    """

    # Cache Settings
    MAX_CACHE_SIZE: int = 100  # Maximum number of cached items
    CACHE_TTL_SECONDS: int = 3600  # Time to live: 1 hour

    # Cache Keys
    ENABLE_CAPTION_CACHE: bool = True
    CACHE_KEY_ALGO: str = "md5"  # Hashing algorithm for cache keys
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ============================================================================
|
| 152 |
+
# ANALYTICS CONFIGURATION
|
| 153 |
+
# ============================================================================
|
| 154 |
+
|
| 155 |
+
@dataclass(frozen=True)
class AnalyticsConfig:
    """
    Configuration for usage analytics

    Immutable (frozen) settings: where analytics are stored, the save
    interval, and per-metric tracking switches.
    """

    # Storage
    # Defaults to the module-level ANALYTICS_FILE path (<cache>/analytics.json).
    ANALYTICS_FILE: Path = ANALYTICS_FILE
    SAVE_INTERVAL_SECONDS: int = 30  # Auto-save every 30 seconds

    # Metrics to Track
    TRACK_PROCESSING_TIME: bool = True
    TRACK_STYLE_USAGE: bool = True
    TRACK_MODEL_USAGE: bool = True
    TRACK_ERROR_RATE: bool = True
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ============================================================================
|
| 171 |
+
# GRADIO UI CONFIGURATION
|
| 172 |
+
# ============================================================================
|
| 173 |
+
|
| 174 |
+
@dataclass(frozen=True)
class UIConfig:
    """
    Configuration for Gradio interface

    Immutable (frozen) settings: app title and description, launch flags,
    component sizing and the directory scanned for example images.
    """

    # App Metadata
    TITLE: str = "🖼️ AI Image Caption Generator"
    DESCRIPTION: str = """
    Generate professional image captions using state-of-the-art AI models.
    Upload an image and choose your preferred style - get instant captions from both BLIP and GIT models.
    """

    # UI Settings
    THEME: str = "soft"  # Gradio theme
    SHOW_API: bool = False
    SHOW_ERROR: bool = True

    # Component Settings
    IMAGE_HEIGHT: int = 400
    MAX_QUEUE_SIZE: int = 10

    # Example Images
    # Derived from the module-level STATIC_DIR; this directory is not created
    # by the mkdir loop at the top of the module.
    EXAMPLES_DIR: Path = STATIC_DIR / "images" / "examples"
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ============================================================================
|
| 199 |
+
# LOGGING CONFIGURATION
|
| 200 |
+
# ============================================================================
|
| 201 |
+
|
| 202 |
+
@dataclass(frozen=True)
class LogConfig:
    """
    Configuration for logging

    Immutable (frozen) settings: log level plus the record and date format
    strings.
    """

    # Read from the LOG_LEVEL environment variable once, while this class
    # body runs; falls back to "INFO" when the variable is unset.
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
    LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    LOG_DATE_FORMAT: str = "%Y-%m-%d %H:%M:%S"
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# ============================================================================
|
| 212 |
+
# PERFORMANCE CONFIGURATION
|
| 213 |
+
# ============================================================================
|
| 214 |
+
|
| 215 |
+
@dataclass(frozen=True)
class PerformanceConfig:
    """
    Configuration for performance optimization

    Immutable (frozen) settings: processing timeout, model loading strategy
    and batch-processing switches.
    """

    # Processing Timeouts
    MAX_PROCESSING_TIME_SECONDS: int = 30

    # Model Loading
    LAZY_LOAD_MODELS: bool = False  # Load models on first use vs startup

    # Batch Processing (future feature)
    ENABLE_BATCH_PROCESSING: bool = False
    MAX_BATCH_SIZE: int = 1
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ============================================================================
|
| 231 |
+
# INSTANTIATE CONFIGURATIONS
|
| 232 |
+
# ============================================================================
|
| 233 |
+
|
| 234 |
+
# Create singleton instances
# One module-level instance per config class, built with default values.
# Other modules import these instances (e.g. `from config import model_config`)
# rather than instantiating the classes themselves.
model_config = ModelConfig()
image_config = ImageConfig()
groq_config = GroqConfig()
style_config = StyleConfig()
cache_config = CacheConfig()
analytics_config = AnalyticsConfig()
ui_config = UIConfig()
log_config = LogConfig()
performance_config = PerformanceConfig()
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# ============================================================================
|
| 247 |
+
# VALIDATION
|
| 248 |
+
# ============================================================================
|
| 249 |
+
|
| 250 |
+
def validate_config() -> tuple[bool, list[str]]:
    """
    Check the module-level configuration singletons for problems

    Verifies that a Groq API key is set, that the cache directories exist,
    that the image dimension limits are not inverted and that the default
    style is one of the configured styles. Nothing is raised; every problem
    found is collected as a message.

    Returns:
        tuple: (is_valid, list_of_errors)
    """
    problems = []

    # Groq API key
    if not groq_config.API_KEY:
        problems.append("GROQ_API_KEY not found in environment variables")

    # Required directories
    problems.extend(
        f"Required directory not found: {path}"
        for path in (CACHE_DIR, MODEL_CACHE_DIR)
        if not path.exists()
    )

    # Image constraints
    if image_config.MAX_DIMENSION < image_config.MIN_DIMENSION:
        problems.append("MAX_DIMENSION must be greater than MIN_DIMENSION")

    # Style options
    if not style_config.STYLES:
        problems.append("No style options configured")

    if style_config.DEFAULT_STYLE not in style_config.STYLES:
        problems.append(
            f"Default style '{style_config.DEFAULT_STYLE}' not in available styles"
        )

    return not problems, problems
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# ============================================================================
|
| 284 |
+
# CONFIGURATION SUMMARY
|
| 285 |
+
# ============================================================================
|
| 286 |
+
|
| 287 |
+
def print_config_summary() -> None:
    """
    Print configuration summary for debugging

    Writes the main paths, model names, Groq settings (only whether the API
    key is present, never its value), image limits, style names and cache
    and analytics settings to stdout, followed by the result of
    ``validate_config()``.
    """
    rule = "=" * 60

    print(rule)
    print("AI IMAGE CAPTION GENERATOR - CONFIGURATION SUMMARY")
    print(rule)
    print(f"Project Root: {PROJECT_ROOT}")
    print(f"Cache Directory: {CACHE_DIR}")
    print(f"Model Cache: {MODEL_CACHE_DIR}")

    print("\nModels:")
    print(f" - BLIP: {model_config.BLIP_MODEL_NAME}")
    print(f" - GIT: {model_config.GIT_MODEL_NAME}")
    print(f" - Device: {model_config.DEVICE}")

    print("\nGroq API:")
    print(f" - Model: {groq_config.MODEL_NAME}")
    key_status = '✓ Configured' if groq_config.API_KEY else '✗ Missing'
    print(f" - API Key: {key_status}")

    print("\nImage Processing:")
    print(f" - Max Size: {image_config.MAX_FILE_SIZE_MB}MB")
    print(f" - Max Dimension: {image_config.MAX_DIMENSION}px")
    print(f" - Formats: {', '.join(image_config.ALLOWED_FORMATS)}")

    print(f"\nStyle Options: {len(style_config.STYLES)}")
    for style_name in style_config.STYLES:
        print(f" - {style_name}")

    print(f"\nCache: {cache_config.MAX_CACHE_SIZE} items")
    print(f"Analytics: {analytics_config.ANALYTICS_FILE}")
    print(rule)

    # Validate configuration
    config_ok, problems = validate_config()
    if config_ok:
        print("\n✓ Configuration validated successfully")
    else:
        print("\n⚠️ CONFIGURATION ERRORS:")
        for problem in problems:
            print(f" - {problem}")
    print(rule)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
if __name__ == "__main__":
    # Run configuration validation when executed directly
    # (print_config_summary() prints the summary and then calls validate_config()).
    print_config_summary()
|
docker-compose.yml
ADDED
|
File without changes
|
requirements.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core Framework
|
| 2 |
+
torch==2.1.0
|
| 3 |
+
torchvision==0.16.0
|
| 4 |
+
transformers==4.35.0
|
| 5 |
+
gradio==4.8.0
|
| 6 |
+
accelerate==0.25.0
|
| 7 |
+
|
| 8 |
+
# Image Processing
|
| 9 |
+
Pillow==10.0.1
|
| 10 |
+
opencv-python==4.8.1.78
|
| 11 |
+
|
| 12 |
+
# API Integration
|
| 13 |
+
groq>=0.33.0
|
| 14 |
+
requests==2.31.0
|
| 15 |
+
|
| 16 |
+
# Utilities
|
| 17 |
+
python-dotenv==1.0.0
|
| 18 |
+
numpy==1.24.3
|
| 19 |
+
tqdm==4.66.1
|
| 20 |
+
|
| 21 |
+
# Development
|
| 22 |
+
pytest==7.4.3
|
| 23 |
+
black==23.9.1
|
| 24 |
+
flake8==6.1.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Source Package Initialization
|
| 3 |
+
|
| 4 |
+
Makes src a proper Python package and exposes key components.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
__version__ = "1.0.0"
|
| 8 |
+
__author__ = "AI Caption Generator Team"
|
| 9 |
+
|
| 10 |
+
# This file makes src a proper Python package
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Models Package
|
| 3 |
+
|
| 4 |
+
Provides caption generation and styling models.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .caption_model import (
|
| 8 |
+
CaptionModel,
|
| 9 |
+
BLIPModel,
|
| 10 |
+
GITModel,
|
| 11 |
+
CaptionModelManager,
|
| 12 |
+
CaptionModelError,
|
| 13 |
+
get_model_manager
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
from .style_model import (
|
| 17 |
+
StyleModel,
|
| 18 |
+
StyleModelError,
|
| 19 |
+
get_style_model
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
# Caption Models
|
| 24 |
+
"CaptionModel",
|
| 25 |
+
"BLIPModel",
|
| 26 |
+
"GITModel",
|
| 27 |
+
"CaptionModelManager",
|
| 28 |
+
"CaptionModelError",
|
| 29 |
+
"get_model_manager",
|
| 30 |
+
|
| 31 |
+
# Style Model
|
| 32 |
+
"StyleModel",
|
| 33 |
+
"StyleModelError",
|
| 34 |
+
"get_style_model",
|
| 35 |
+
]
|
src/models/caption_model.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Caption Model Module
|
| 3 |
+
|
| 4 |
+
Manages BLIP and GIT models for image caption generation.
|
| 5 |
+
Handles model loading, inference, and memory management.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from typing import Optional, Dict, Tuple
|
| 11 |
+
from transformers import (
|
| 12 |
+
BlipProcessor,
|
| 13 |
+
BlipForConditionalGeneration,
|
| 14 |
+
AutoProcessor,
|
| 15 |
+
AutoModelForCausalLM
|
| 16 |
+
)
|
| 17 |
+
import gc
|
| 18 |
+
|
| 19 |
+
from config import model_config
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class CaptionModelError(Exception):
    """Raised when a caption model is used before loading or fails to generate"""
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CaptionModel:
    """
    Common base for the image captioning model wrappers

    Holds the model name, the resolved device and the processor/model
    handles. Subclasses supply ``load()`` and ``generate_caption()``.
    """

    def __init__(self, model_name: str, device: str = "cuda"):
        """
        Store the model identifier and resolve the device

        Args:
            model_name: HuggingFace model identifier
            device: Device to load model on (cuda/cpu)
        """
        self.model_name = model_name
        self.device = self._get_device(device)
        # Nothing is loaded yet; load() fills these in.
        self.processor = None
        self.model = None
        self._is_loaded = False

    def _get_device(self, requested_device: str) -> str:
        """
        Resolve the device that will actually be used

        Args:
            requested_device: Requested device (cuda/cpu)

        Returns:
            str: "cuda" only when it was requested and is available,
            otherwise "cpu"
        """
        cuda_usable = requested_device == "cuda" and torch.cuda.is_available()
        return "cuda" if cuda_usable else "cpu"

    def load(self) -> bool:
        """
        Load model into memory (implemented by subclasses)

        Returns:
            bool: True if successful
        """
        raise NotImplementedError("Subclass must implement load()")

    def generate_caption(
        self,
        image: Image.Image,
        max_length: int = 50,
        num_beams: int = 3
    ) -> str:
        """
        Generate caption for image (implemented by subclasses)

        Args:
            image: PIL Image
            max_length: Maximum caption length
            num_beams: Number of beams for beam search

        Returns:
            str: Generated caption
        """
        raise NotImplementedError("Subclass must implement generate_caption()")

    def unload(self) -> None:
        """Drop the model and processor and release their memory"""
        # Rebinding to None releases this object's references to both.
        self.model = None
        self.processor = None

        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()

        self._is_loaded = False

    def is_loaded(self) -> bool:
        """Report whether load() has completed successfully"""
        return self._is_loaded

    def get_info(self) -> dict:
        """Return the model name, device and loaded flag as a dict"""
        return dict(
            model_name=self.model_name,
            device=self.device,
            is_loaded=self._is_loaded,
        )
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class BLIPModel(CaptionModel):
    """
    BLIP (Bootstrapping Language-Image Pre-training) model

    Fast and efficient model for image captioning
    """

    def __init__(self, device: str = "cuda"):
        """
        Initialize BLIP model

        Args:
            device: Requested device (cuda/cpu); resolved by the base class
        """
        super().__init__(model_config.BLIP_MODEL_NAME, device)
        self.max_length = model_config.BLIP_MAX_LENGTH
        self.num_beams = model_config.BLIP_NUM_BEAMS

    def load(self) -> bool:
        """
        Load BLIP model and processor

        Weights are loaded in float16 on CUDA and float32 on CPU. Failures
        are printed rather than raised.

        Returns:
            bool: True if successful, False if loading failed
        """
        try:
            print(f"Loading BLIP model on {self.device}...")

            # Load processor
            self.processor = BlipProcessor.from_pretrained(
                self.model_name,
                cache_dir=model_config.MODEL_CACHE_DIR
            )

            # Load model
            self.model = BlipForConditionalGeneration.from_pretrained(
                self.model_name,
                cache_dir=model_config.MODEL_CACHE_DIR,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            ).to(self.device)

            # Set to evaluation mode
            self.model.eval()

            self._is_loaded = True
            print(f"✓ BLIP model loaded successfully on {self.device}")
            return True

        except Exception as e:
            print(f"Error loading BLIP model: {e}")
            # Do not keep a half-initialised pair (e.g. processor loaded but
            # model failed) around after a failed load.
            self.processor = None
            self.model = None
            self._is_loaded = False
            return False

    def generate_caption(
        self,
        image: Image.Image,
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None
    ) -> str:
        """
        Generate caption using BLIP

        Args:
            image: PIL Image
            max_length: Maximum caption length (defaults to the configured
                BLIP_MAX_LENGTH)
            num_beams: Number of beams for beam search (defaults to the
                configured BLIP_NUM_BEAMS)

        Returns:
            str: Generated caption

        Raises:
            CaptionModelError: If the model is not loaded or generation fails
        """
        if not self._is_loaded:
            raise CaptionModelError("BLIP model not loaded")

        try:
            # Use default values if not provided
            max_length = max_length or self.max_length
            num_beams = num_beams or self.num_beams

            # Preprocess image
            # The model runs in float16 on CUDA, so the pixel values must be
            # cast to the model's dtype as well as moved to its device;
            # float32 inputs against half-precision weights fail inside
            # generate(). On CPU the model dtype is float32, so the cast is
            # a no-op.
            inputs = self.processor(
                images=image,
                return_tensors="pt"
            ).to(self.device, self.model.dtype)

            # Generate caption
            with torch.no_grad():
                output_ids = self.model.generate(
                    **inputs,
                    max_length=max_length,
                    num_beams=num_beams,
                    early_stopping=True
                )

            # Decode caption
            caption = self.processor.decode(
                output_ids[0],
                skip_special_tokens=True
            )

            return caption.strip()

        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise CaptionModelError(f"BLIP caption generation failed: {e}") from e
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class GITModel(CaptionModel):
    """
    GIT (Generative Image-to-text Transformer) model

    More detailed and accurate captions compared to BLIP
    """

    def __init__(self, device: str = "cuda"):
        """
        Initialize GIT model

        Args:
            device: Device string passed to the base class and later used
                for both the weights and the inputs
        """
        super().__init__(model_config.GIT_MODEL_NAME, device)
        self.max_length = model_config.GIT_MAX_LENGTH
        self.num_beams = model_config.GIT_NUM_BEAMS

    def load(self) -> bool:
        """
        Load GIT model and processor

        Errors are printed rather than raised so that callers can decide
        how to react to a missing model.

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            print(f"Loading GIT model on {self.device}...")

            # Load processor
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                cache_dir=model_config.MODEL_CACHE_DIR
            )

            # Half precision on any CUDA device ("cuda", "cuda:0", ...),
            # full precision everywhere else.
            on_gpu = str(self.device).startswith("cuda")

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                cache_dir=model_config.MODEL_CACHE_DIR,
                torch_dtype=torch.float16 if on_gpu else torch.float32
            ).to(self.device)

            # Set to evaluation mode
            self.model.eval()

            self._is_loaded = True
            print(f"✓ GIT model loaded successfully on {self.device}")
            return True

        except Exception as e:
            print(f"Error loading GIT model: {e}")
            self._is_loaded = False
            return False

    def generate_caption(
        self,
        image: Image.Image,
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None
    ) -> str:
        """
        Generate caption using GIT

        Args:
            image: PIL Image
            max_length: Maximum caption length (instance default if omitted)
            num_beams: Number of beams for beam search (instance default
                if omitted)

        Returns:
            str: Generated caption, stripped of surrounding whitespace

        Raises:
            CaptionModelError: If the model is not loaded or generation fails
        """
        if not self._is_loaded:
            raise CaptionModelError("GIT model not loaded")

        try:
            # Use default values if not provided
            max_length = max_length or self.max_length
            num_beams = num_beams or self.num_beams

            # Preprocess image. The pixel tensor must use the same dtype as
            # the weights: a float16 model fed float32 inputs fails inside
            # generate(), so cast to the model dtype while moving devices.
            inputs = self.processor(
                images=image,
                return_tensors="pt"
            ).to(self.device, self.model.dtype)

            # Generate caption
            with torch.no_grad():
                output_ids = self.model.generate(
                    pixel_values=inputs.pixel_values,
                    max_length=max_length,
                    num_beams=num_beams,
                    early_stopping=True
                )

            # Decode caption
            caption = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True
            )[0]

            return caption.strip()

        except Exception as e:
            # Keep the original exception attached for debugging
            raise CaptionModelError(f"GIT caption generation failed: {e}") from e
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class CaptionModelManager:
    """
    Manager for both BLIP and GIT models

    Provides unified interface and handles model lifecycle
    """

    def __init__(self, device: Optional[str] = None):
        """
        Initialize model manager

        Args:
            device: Device to use (cuda/cpu); falls back to
                model_config.DEVICE when not given
        """
        self.device = device or model_config.DEVICE

        # Initialize models (weights are not loaded until load_* is called)
        self.blip_model = BLIPModel(self.device)
        self.git_model = GITModel(self.device)

        # Track which models are loaded
        self._loaded_models = set()

    def _models(self) -> dict:
        """Map each supported model key to its wrapper, in a fixed order."""
        return {"blip": self.blip_model, "git": self.git_model}

    def load_all_models(self) -> Tuple[bool, bool]:
        """
        Load both models

        Returns:
            Tuple[bool, bool]: (blip_success, git_success)
        """
        blip_success = self.load_model("blip")
        git_success = self.load_model("git")
        return blip_success, git_success

    def load_model(self, model_name: str) -> bool:
        """
        Load specific model

        Args:
            model_name: Model to load ("blip" or "git", case-insensitive)

        Returns:
            bool: True if successful

        Raises:
            ValueError: If model_name is not a supported model
        """
        key = model_name.lower()
        models = self._models()
        if key not in models:
            raise ValueError(f"Unknown model: {model_name}")

        success = models[key].load()
        if success:
            self._loaded_models.add(key)
        else:
            # A failed (re)load marks the wrapper as not loaded, so the key
            # must not linger here and make generate_captions() call it.
            self._loaded_models.discard(key)
        return success

    def generate_captions(
        self,
        image: Image.Image
    ) -> Dict[str, str]:
        """
        Generate captions from all loaded models

        A failure in one model does not stop the others: its entry holds
        an "Error: ..." message instead of a caption.

        Args:
            image: PIL Image

        Returns:
            Dict[str, str]: Captions (or error messages) keyed by model
        """
        captions = {}

        for key, model in self._models().items():
            if key not in self._loaded_models:
                continue
            try:
                captions[key] = model.generate_caption(image)
            except Exception as e:
                captions[key] = f"Error: {str(e)}"

        return captions

    def unload_all_models(self) -> None:
        """Unload all models from memory"""
        self.blip_model.unload()
        self.git_model.unload()
        self._loaded_models.clear()

    def get_status(self) -> dict:
        """
        Get status of all models

        Returns:
            dict: The device, a per-model "loaded"/"info" entry and the
                sorted list of loaded model keys
        """
        return {
            "device": self.device,
            "blip": {
                "loaded": self.blip_model.is_loaded(),
                "info": self.blip_model.get_info()
            },
            "git": {
                "loaded": self.git_model.is_loaded(),
                "info": self.git_model.get_info()
            },
            # Sorted so the report is stable between calls
            "loaded_models": sorted(self._loaded_models)
        }
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
# Module-wide manager, created lazily by get_model_manager()
_model_manager = None


def get_model_manager() -> CaptionModelManager:
    """Return the shared CaptionModelManager, creating it on first use"""
    global _model_manager
    if _model_manager is not None:
        return _model_manager
    _model_manager = CaptionModelManager()
    return _model_manager
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
if __name__ == "__main__":
    # Manual smoke test: build a manager, load both models, dump the status
    rule = "=" * 60

    print(rule)
    print("CAPTION MODELS - TEST MODE")
    print(rule)

    # Build the manager (no weights are loaded yet)
    manager = CaptionModelManager()
    print(f"\n✓ Model manager initialized")
    print(f" Device: {manager.device}")

    print("\n" + rule)
    print("Loading models (this may take a few minutes)...")
    print(rule)

    # Load both models and report each outcome
    blip_success, git_success = manager.load_all_models()
    blip_mark = '✓ Loaded' if blip_success else '✗ Failed'
    git_mark = '✓ Loaded' if git_success else '✗ Failed'
    print(f"\nBLIP: {blip_mark}")
    print(f"GIT: {git_mark}")

    print("\n" + rule)
    print("Model Status:")
    print(rule)
    for key, value in manager.get_status().items():
        if not isinstance(value, dict):
            print(f"{key}: {value}")
            continue
        # Nested sections are printed one entry per line
        print(f"{key}:")
        for k, v in value.items():
            print(f" {k}: {v}")

    print("\n" + rule)
    print("✓ Caption models test complete")
    print(rule)

    # Usage hint for trying caption generation by hand
    for hint in (
        "\nTo test caption generation, provide a test image:",
        " from PIL import Image",
        " img = Image.open('your_image.jpg')",
        " captions = manager.generate_captions(img)",
        " print(captions)",
    ):
        print(hint)
|
src/models/style_model.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Style Model Module
|
| 3 |
+
|
| 4 |
+
Handles caption styling using Groq API with fallback mechanisms.
|
| 5 |
+
Applies different writing styles to generated captions.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
from typing import Optional
|
| 10 |
+
from groq import Groq
|
| 11 |
+
import requests
|
| 12 |
+
|
| 13 |
+
from config import groq_config, style_config
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class StyleModelError(Exception):
    """Error type raised by the style model when API styling fails"""
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class StyleModel:
    """
    Caption styling using Groq LLM API

    Features:
    - Multiple style options (taken from style_config.STYLES)
    - Fixed-delay retries of failed API calls
    - Fallback to rule-based (template) styling
    """

    # Lead-ins the LLM sometimes puts in front of the actual caption
    _RESPONSE_PREFIXES = (
        "Styled caption:",
        "Caption:",
        "Here's the styled caption:",
        "Here is the caption:",
    )

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize style model

        Args:
            api_key: Groq API key (uses config if not provided)
        """
        self.api_key = api_key or groq_config.API_KEY
        self.model_name = groq_config.MODEL_NAME
        self.max_tokens = groq_config.MAX_TOKENS
        self.temperature = groq_config.TEMPERATURE
        self.timeout = groq_config.TIMEOUT_SECONDS

        # Retry configuration
        self.max_retries = groq_config.MAX_RETRIES
        self.retry_delay = groq_config.RETRY_DELAY_SECONDS

        # Initialize Groq client
        if self.api_key:
            try:
                self.client = Groq(
                    api_key=self.api_key
                )
                self._api_available = True
                # Probe the API once; a failure drops into the retry below
                _ = self.client.models.list()
            except Exception as e:
                print(f"Warning: Groq client initialization failed: {e}")
                print(f"Attempting alternative initialization...")
                try:
                    # Alternative: Create client without extra params.
                    # NOTE(review): no probe call is made here, so a client
                    # that merely constructs is treated as available.
                    import groq
                    self.client = groq.Client(api_key=self.api_key)
                    self._api_available = True
                except Exception as e2:
                    print(f"Alternative initialization also failed: {e2}")
                    self.client = None
                    self._api_available = False
        else:
            print("Warning: No Groq API key provided")
            self.client = None
            self._api_available = False

    def style_caption(
        self,
        caption: str,
        style: str = "Professional"
    ) -> str:
        """
        Apply style to caption

        Uses the API with retries when it is available and the template
        fallback otherwise; this method itself never raises for API errors.

        Args:
            caption: Original caption
            style: Style to apply ("None" returns the caption unchanged)

        Returns:
            str: Styled caption
        """
        if style == "None":
            return caption

        if not self._api_available:
            # Use fallback styling if API unavailable
            return self._fallback_style(caption, style)

        # Try API styling with retries
        for attempt in range(self.max_retries):
            try:
                return self._style_with_api(caption, style)
            except Exception as e:
                print(f"API styling attempt {attempt + 1} failed: {e}")

                # If last attempt, use fallback
                if attempt == self.max_retries - 1:
                    print(f"Using fallback styling for: {style}")
                    return self._fallback_style(caption, style)

                # Wait before retry
                time.sleep(self.retry_delay)

        # Reached only when max_retries allows no attempt at all
        return self._fallback_style(caption, style)

    def _style_with_api(self, caption: str, style: str) -> str:
        """
        Style caption using Groq API

        Args:
            caption: Original caption
            style: Style to apply (unknown styles use the default style)

        Returns:
            str: Styled caption

        Raises:
            StyleModelError: If the API is unavailable, the call fails or
                the response contains no usable text
        """
        if not self._api_available:
            raise StyleModelError("API not available")

        # Imported here, like the fallback in __init__, because only the
        # Groq class is imported at module level.
        import groq

        # Get style prompt
        style_prompt = style_config.STYLES.get(
            style,
            style_config.STYLES[style_config.DEFAULT_STYLE]
        )

        # Construct messages
        messages = [
            {
                "role": "system",
                "content": "You are an expert at rewriting image captions in different styles. Keep the core meaning but adapt the tone and style as requested. Be concise."
            },
            {
                "role": "user",
                "content": f"{style_prompt}\n\nOriginal caption: {caption}\n\nStyled caption:"
            }
        ]

        try:
            # Make API call
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                top_p=groq_config.TOP_P,
                timeout=self.timeout
            )
            raw_text = response.choices[0].message.content

        # The Groq SDK raises its own error types; the requests ones are
        # kept only for backward compatibility.
        except (groq.APITimeoutError, requests.exceptions.Timeout) as e:
            raise StyleModelError("API request timed out") from e
        except (groq.APIError, requests.exceptions.RequestException) as e:
            raise StyleModelError(f"API request failed: {e}") from e
        except Exception as e:
            raise StyleModelError(f"Unexpected error: {e}") from e

        # Extract styled caption and clean up common artifacts
        styled_caption = self._clean_response((raw_text or "").strip())

        # An empty completion is a failure, so the caller retries or falls
        # back instead of showing a blank caption.
        if not styled_caption:
            raise StyleModelError("API returned an empty caption")

        return styled_caption

    def _fallback_style(self, caption: str, style: str) -> str:
        """
        Apply rule-based styling as fallback

        Args:
            caption: Original caption
            style: Style to apply (unknown styles use the "Professional"
                template)

        Returns:
            str: Styled caption using templates
        """
        template = style_config.FALLBACK_TEMPLATES.get(
            style,
            style_config.FALLBACK_TEMPLATES["Professional"]
        )

        return template.format(caption=caption)

    def _clean_response(self, text: str) -> str:
        """
        Clean up API response

        Args:
            text: Raw response text

        Returns:
            str: Text without known lead-in phrases or wrapping quotes
        """
        # Remove common prefixes (case-insensitive, checked in order)
        for prefix in self._RESPONSE_PREFIXES:
            if text.lower().startswith(prefix.lower()):
                text = text[len(prefix):].strip()

        # Remove quotes if the entire text is quoted
        if text[:1] in ('"', "'") and text[-1:] == text[:1]:
            text = text[1:-1]

        return text.strip()

    def batch_style_captions(
        self,
        captions: dict,
        style: str = "Professional"
    ) -> dict:
        """
        Style multiple captions at once

        Args:
            captions: Dictionary of {model_name: caption}
            style: Style to apply

        Returns:
            dict: Dictionary of {model_name: styled_caption}
        """
        styled_captions = {}

        for model_name, caption in captions.items():
            try:
                styled_captions[model_name] = self.style_caption(caption, style)
            except Exception as e:
                print(f"Error styling {model_name} caption: {e}")
                # Use original caption on error
                styled_captions[model_name] = caption

        return styled_captions

    def is_api_available(self) -> bool:
        """Check if API is available"""
        return self._api_available

    def test_connection(self) -> bool:
        """
        Test API connection

        Returns:
            bool: True if a minimal chat completion request succeeds
        """
        if not self._api_available:
            return False

        try:
            # Simple test call; only success or failure matters
            self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "user", "content": "Hello"}
                ],
                max_tokens=10,
                timeout=5
            )
            return True
        except Exception as e:
            print(f"API connection test failed: {e}")
            return False

    def get_available_styles(self) -> list:
        """Get list of available styles"""
        return list(style_config.STYLES.keys())

    def get_info(self) -> dict:
        """Get model information"""
        return {
            "model_name": self.model_name,
            "api_available": self._api_available,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "available_styles": self.get_available_styles()
        }
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# Module-wide style model, created lazily by get_style_model()
_style_model = None


def get_style_model() -> StyleModel:
    """Return the shared StyleModel, creating it on first use"""
    global _style_model
    if _style_model is not None:
        return _style_model
    _style_model = StyleModel()
    return _style_model
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
if __name__ == "__main__":
    # Manual smoke test: report the configuration, then exercise either
    # the API path or the template fallback
    rule = "=" * 60
    test_caption = "A cat sitting on a windowsill looking outside"
    demo_styles = ("Professional", "Creative", "Social Media")

    print(rule)
    print("STYLE MODEL - TEST MODE")
    print(rule)

    # Build the model
    style_model = StyleModel()

    print(f"\n✓ Style model initialized")
    print(f" API Available: {style_model.is_api_available()}")
    print(f" Model: {style_model.model_name}")

    # Dump the info dictionary; list values get one line per item
    print("\nModel Info:")
    for key, value in style_model.get_info().items():
        if not isinstance(value, list):
            print(f" {key}: {value}")
            continue
        print(f" {key}:")
        for item in value:
            print(f" - {item}")

    if not style_model.is_api_available():
        print("\n⚠️ API not available, testing fallback styling:")
        for style in demo_styles:
            styled = style_model.style_caption(test_caption, style)
            print(f"\n {style}: {styled}")
    else:
        # Test connection if API available
        print("\nTesting API connection...")
        connection_ok = style_model.test_connection()
        outcome = '✓ Success' if connection_ok else '✗ Failed'
        print(f" Connection: {outcome}")

        if connection_ok:
            # Test styling
            print("\nTesting caption styling:")
            for style in demo_styles:
                print(f"\n {style}:")
                try:
                    styled = style_model.style_caption(test_caption, style)
                    print(f" Original: {test_caption}")
                    print(f" Styled: {styled}")
                except Exception as e:
                    print(f" Error: {e}")

    print("\n" + rule)
    print("✓ Style model test complete")
    print(rule)
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utils Package
|
| 3 |
+
|
| 4 |
+
Provides utility functions for image processing, caching, and analytics.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .image_processor import (
|
| 8 |
+
ImageProcessor,
|
| 9 |
+
ImageProcessingError,
|
| 10 |
+
get_image_processor,
|
| 11 |
+
validate_image,
|
| 12 |
+
preprocess_image,
|
| 13 |
+
generate_image_hash
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
from .cache_manager import (
|
| 17 |
+
CacheManager,
|
| 18 |
+
CaptionCache,
|
| 19 |
+
get_cache_manager,
|
| 20 |
+
get_caption_cache
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
from .analytics import (
|
| 24 |
+
AnalyticsManager,
|
| 25 |
+
get_analytics_manager,
|
| 26 |
+
record_generation,
|
| 27 |
+
get_stats,
|
| 28 |
+
get_summary,
|
| 29 |
+
get_display_stats
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
__all__ = [
|
| 33 |
+
# Image Processing
|
| 34 |
+
"ImageProcessor",
|
| 35 |
+
"ImageProcessingError",
|
| 36 |
+
"get_image_processor",
|
| 37 |
+
"validate_image",
|
| 38 |
+
"preprocess_image",
|
| 39 |
+
"generate_image_hash",
|
| 40 |
+
|
| 41 |
+
# Cache Management
|
| 42 |
+
"CacheManager",
|
| 43 |
+
"CaptionCache",
|
| 44 |
+
"get_cache_manager",
|
| 45 |
+
"get_caption_cache",
|
| 46 |
+
|
| 47 |
+
# Analytics
|
| 48 |
+
"AnalyticsManager",
|
| 49 |
+
"get_analytics_manager",
|
| 50 |
+
"record_generation",
|
| 51 |
+
"get_stats",
|
| 52 |
+
"get_summary",
|
| 53 |
+
"get_display_stats",
|
| 54 |
+
]
|
src/utils/analytics.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Analytics Module
|
| 3 |
+
|
| 4 |
+
Tracks usage statistics and performance metrics for the caption generator.
|
| 5 |
+
Provides insights into model usage, processing times, and popular styles.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import threading
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Dict, Optional
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from dataclasses import dataclass, asdict
|
| 14 |
+
|
| 15 |
+
from config import analytics_config, style_config
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class AnalyticsData:
    """
    Container for analytics data

    The usage dictionaries default to None and are filled in by
    __post_init__, so every instance gets its own counters.
    """
    total_captions: int = 0
    style_usage: Optional[Dict[str, int]] = None
    avg_processing_time: float = 0.0
    total_processing_time: float = 0.0
    model_usage: Optional[Dict[str, int]] = None
    error_count: int = 0
    last_updated: Optional[str] = None

    def __post_init__(self):
        """
        Make sure both usage dictionaries exist and hold every known key

        Counters restored from an older file may lack styles added later.
        They are backfilled with 0 here so later usage is counted rather
        than skipped.
        """
        # Copy so a dictionary passed in by the caller is never mutated
        self.style_usage = dict(self.style_usage or {})
        for style in style_config.STYLES.keys():
            self.style_usage.setdefault(style, 0)

        self.model_usage = dict(self.model_usage or {})
        for model_key in ("blip", "git"):
            self.model_usage.setdefault(model_key, 0)

    def to_dict(self) -> dict:
        """Convert to dictionary"""
        return asdict(self)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class AnalyticsManager:
|
| 41 |
+
"""
|
| 42 |
+
Thread-safe analytics manager for tracking usage metrics
|
| 43 |
+
|
| 44 |
+
Features:
|
| 45 |
+
- Real-time metric tracking
|
| 46 |
+
- Persistent storage
|
| 47 |
+
- Thread-safe operations
|
| 48 |
+
- Automatic calculations
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def __init__(self, storage_path: Optional[Path] = None):
    """
    Set up the analytics manager

    Args:
        storage_path: Location of the analytics JSON file; when omitted,
            analytics_config.ANALYTICS_FILE is used
    """
    # Re-entrant lock guarding updates to self.data
    self._lock = threading.RLock()

    if storage_path:
        self.storage_path = storage_path
    else:
        self.storage_path = analytics_config.ANALYTICS_FILE

    # Start from the stored statistics, or from fresh ones
    self.data = self._load_data()
|
| 63 |
+
|
| 64 |
+
def _load_data(self) -> AnalyticsData:
    """
    Read the analytics file, falling back to empty statistics

    Returns:
        AnalyticsData: The stored data, or a fresh instance when the file
            is missing or cannot be loaded
    """
    if not self.storage_path.exists():
        return AnalyticsData()

    try:
        with open(self.storage_path, 'r') as handle:
            stored = json.load(handle)
        return AnalyticsData(**stored)
    except Exception as e:
        # Best effort: an unreadable file must not stop the application
        print(f"Warning: Failed to load analytics: {e}")
        return AnalyticsData()
|
| 81 |
+
|
| 82 |
+
def _save_data(self) -> bool:
|
| 83 |
+
"""
|
| 84 |
+
Save analytics data to file
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
bool: True if successful
|
| 88 |
+
"""
|
| 89 |
+
try:
|
| 90 |
+
# Ensure directory exists
|
| 91 |
+
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
|
| 92 |
+
|
| 93 |
+
# Update timestamp
|
| 94 |
+
self.data.last_updated = datetime.now().isoformat()
|
| 95 |
+
|
| 96 |
+
# Write to file
|
| 97 |
+
with open(self.storage_path, 'w') as f:
|
| 98 |
+
json.dump(self.data.to_dict(), f, indent=4)
|
| 99 |
+
|
| 100 |
+
return True
|
| 101 |
+
except Exception as e:
|
| 102 |
+
print(f"Error saving analytics: {e}")
|
| 103 |
+
return False
|
| 104 |
+
|
| 105 |
+
def record_caption_generation(
|
| 106 |
+
self,
|
| 107 |
+
model_name: str,
|
| 108 |
+
style: str,
|
| 109 |
+
processing_time: float,
|
| 110 |
+
success: bool = True
|
| 111 |
+
) -> None:
|
| 112 |
+
"""
|
| 113 |
+
Record a caption generation event
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
model_name: Name of the model used (blip/git)
|
| 117 |
+
style: Style applied
|
| 118 |
+
processing_time: Time taken in seconds
|
| 119 |
+
success: Whether generation was successful
|
| 120 |
+
"""
|
| 121 |
+
with self._lock:
|
| 122 |
+
if success:
|
| 123 |
+
# Increment counters
|
| 124 |
+
self.data.total_captions += 1
|
| 125 |
+
|
| 126 |
+
# Update style usage
|
| 127 |
+
if style in self.data.style_usage:
|
| 128 |
+
self.data.style_usage[style] += 1
|
| 129 |
+
|
| 130 |
+
# Update model usage
|
| 131 |
+
model_key = model_name.lower()
|
| 132 |
+
if model_key in self.data.model_usage:
|
| 133 |
+
self.data.model_usage[model_key] += 1
|
| 134 |
+
|
| 135 |
+
# Update processing time
|
| 136 |
+
self.data.total_processing_time += processing_time
|
| 137 |
+
self.data.avg_processing_time = (
|
| 138 |
+
self.data.total_processing_time / self.data.total_captions
|
| 139 |
+
)
|
| 140 |
+
else:
|
| 141 |
+
self.data.error_count += 1
|
| 142 |
+
|
| 143 |
+
# Save to disk
|
| 144 |
+
self._save_data()
|
| 145 |
+
|
| 146 |
+
def record_batch_generation(
|
| 147 |
+
self,
|
| 148 |
+
generations: list[dict]
|
| 149 |
+
) -> None:
|
| 150 |
+
"""
|
| 151 |
+
Record multiple caption generations at once
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
generations: List of generation records
|
| 155 |
+
Each record: {model_name, style, processing_time, success}
|
| 156 |
+
"""
|
| 157 |
+
with self._lock:
|
| 158 |
+
for gen in generations:
|
| 159 |
+
self.record_caption_generation(
|
| 160 |
+
model_name=gen.get("model_name", "unknown"),
|
| 161 |
+
style=gen.get("style", "None"),
|
| 162 |
+
processing_time=gen.get("processing_time", 0.0),
|
| 163 |
+
success=gen.get("success", True)
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
def get_stats(self) -> dict:
|
| 167 |
+
"""
|
| 168 |
+
Get current statistics
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
dict: Current analytics data
|
| 172 |
+
"""
|
| 173 |
+
with self._lock:
|
| 174 |
+
return self.data.to_dict()
|
| 175 |
+
|
| 176 |
+
def get_summary(self) -> dict:
|
| 177 |
+
"""
|
| 178 |
+
Get formatted summary of analytics
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
dict: Human-readable summary
|
| 182 |
+
"""
|
| 183 |
+
with self._lock:
|
| 184 |
+
total = self.data.total_captions
|
| 185 |
+
|
| 186 |
+
# Calculate percentages for styles
|
| 187 |
+
style_percentages = {}
|
| 188 |
+
if total > 0:
|
| 189 |
+
for style, count in self.data.style_usage.items():
|
| 190 |
+
style_percentages[style] = round((count / total) * 100, 1)
|
| 191 |
+
|
| 192 |
+
# Calculate percentages for models
|
| 193 |
+
model_percentages = {}
|
| 194 |
+
if total > 0:
|
| 195 |
+
for model, count in self.data.model_usage.items():
|
| 196 |
+
model_percentages[model] = round((count / total) * 100, 1)
|
| 197 |
+
|
| 198 |
+
# Find most popular style
|
| 199 |
+
popular_style = max(
|
| 200 |
+
self.data.style_usage.items(),
|
| 201 |
+
key=lambda x: x[1]
|
| 202 |
+
)[0] if self.data.style_usage else "None"
|
| 203 |
+
|
| 204 |
+
return {
|
| 205 |
+
"total_captions": total,
|
| 206 |
+
"avg_processing_time": round(self.data.avg_processing_time, 2),
|
| 207 |
+
"error_rate": round(
|
| 208 |
+
(self.data.error_count / (total + self.data.error_count) * 100)
|
| 209 |
+
if (total + self.data.error_count) > 0 else 0,
|
| 210 |
+
2
|
| 211 |
+
),
|
| 212 |
+
"most_popular_style": popular_style,
|
| 213 |
+
"style_distribution": style_percentages,
|
| 214 |
+
"model_distribution": model_percentages,
|
| 215 |
+
"last_updated": self.data.last_updated
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
def get_display_stats(self) -> str:
|
| 219 |
+
"""
|
| 220 |
+
Get formatted stats for UI display
|
| 221 |
+
|
| 222 |
+
Returns:
|
| 223 |
+
str: Formatted statistics string
|
| 224 |
+
"""
|
| 225 |
+
with self._lock:
|
| 226 |
+
summary = self.get_summary()
|
| 227 |
+
|
| 228 |
+
stats_text = (
|
| 229 |
+
f"📊 Total Captions: {summary['total_captions']} | "
|
| 230 |
+
f"⚡ Avg Time: {summary['avg_processing_time']}s | "
|
| 231 |
+
f"🎨 Popular Style: {summary['most_popular_style']}"
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
return stats_text
|
| 235 |
+
|
| 236 |
+
def reset_stats(self) -> bool:
|
| 237 |
+
"""
|
| 238 |
+
Reset all statistics
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
bool: True if successful
|
| 242 |
+
"""
|
| 243 |
+
with self._lock:
|
| 244 |
+
self.data = AnalyticsData()
|
| 245 |
+
return self._save_data()
|
| 246 |
+
|
| 247 |
+
def export_stats(self, export_path: Optional[Path] = None) -> bool:
|
| 248 |
+
"""
|
| 249 |
+
Export statistics to a file
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
export_path: Path to export file (default: timestamped file)
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
bool: True if successful
|
| 256 |
+
"""
|
| 257 |
+
with self._lock:
|
| 258 |
+
if export_path is None:
|
| 259 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 260 |
+
export_path = self.storage_path.parent / f"analytics_export_{timestamp}.json"
|
| 261 |
+
|
| 262 |
+
try:
|
| 263 |
+
with open(export_path, 'w') as f:
|
| 264 |
+
export_data = {
|
| 265 |
+
"exported_at": datetime.now().isoformat(),
|
| 266 |
+
"statistics": self.data.to_dict(),
|
| 267 |
+
"summary": self.get_summary()
|
| 268 |
+
}
|
| 269 |
+
json.dump(export_data, f, indent=4)
|
| 270 |
+
return True
|
| 271 |
+
except Exception as e:
|
| 272 |
+
print(f"Error exporting analytics: {e}")
|
| 273 |
+
return False
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# Process-wide AnalyticsManager, created lazily on first request
_analytics_manager = None
_manager_lock = threading.Lock()


def get_analytics_manager() -> AnalyticsManager:
    """Return the shared AnalyticsManager, creating it on first use."""
    global _analytics_manager
    if _analytics_manager is not None:
        return _analytics_manager

    with _manager_lock:
        # Re-check under the lock: another thread may have created it
        # between the test above and acquiring the lock.
        if _analytics_manager is None:
            _analytics_manager = AnalyticsManager()
        return _analytics_manager
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
# Convenience functions
|
| 292 |
+
def record_generation(
    model_name: str,
    style: str,
    processing_time: float,
    success: bool = True
) -> None:
    """Forward one caption-generation event to the shared analytics manager."""
    manager = get_analytics_manager()
    manager.record_caption_generation(model_name, style, processing_time, success)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def get_stats() -> dict:
    """Return the raw statistics dictionary from the shared analytics manager."""
    manager = get_analytics_manager()
    return manager.get_stats()
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def get_summary() -> dict:
    """Return the computed summary from the shared analytics manager."""
    manager = get_analytics_manager()
    return manager.get_summary()
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def get_display_stats() -> str:
    """Return the one-line display string from the shared analytics manager."""
    manager = get_analytics_manager()
    return manager.get_display_stats()
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
if __name__ == "__main__":
    # Manual smoke test: exercises AnalyticsManager against a throwaway file.
    banner = "=" * 60
    print(banner)
    print("ANALYTICS MANAGER - TEST MODE")
    print(banner)

    # Use a dedicated file so real analytics are not touched
    test_path = Path("cache/test_analytics.json")
    analytics = AnalyticsManager(storage_path=test_path)

    print("\n1. Initial state:")
    print(f" {analytics.get_display_stats()}")

    print("\n2. Recording test generations:")
    sample_events = (
        ("blip", "Professional", 2.5, True),
        ("git", "Creative", 3.2, True),
        ("blip", "Professional", 2.1, True),
        ("git", "Social Media", 2.8, True),
        ("blip", "Technical", 2.3, False),
    )
    for event in sample_events:
        analytics.record_caption_generation(*event)
    print(f" Recorded 5 generations (4 success, 1 error)")

    # Sections 3 and 4 print a (possibly nested) dict the same way; the
    # getter is called only after its heading has been printed.
    report_sections = (
        ("\n3. Current statistics:", analytics.get_stats),
        ("\n4. Summary:", analytics.get_summary),
    )
    for heading, fetch in report_sections:
        print(heading)
        for key, value in fetch().items():
            if not isinstance(value, dict):
                print(f" {key}: {value}")
                continue
            print(f" {key}:")
            for k, v in value.items():
                print(f" {k}: {v}")

    print("\n5. Display format:")
    print(f" {analytics.get_display_stats()}")

    print("\n6. File saved to:")
    print(f" {test_path}")

    print("\n" + banner)
    print("✓ Analytics manager tests complete")
    print(banner)

    # Remove the throwaway analytics file
    if test_path.exists():
        test_path.unlink()
        print("\n✓ Test file cleaned up")
|
src/utils/cache_manager.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cache Management Module
|
| 3 |
+
|
| 4 |
+
Implements intelligent caching for caption generation results.
|
| 5 |
+
Uses LRU (Least Recently Used) eviction policy for memory efficiency.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
import json
|
| 10 |
+
from typing import Optional, Any, Dict
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
from dataclasses import dataclass, asdict
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
import threading
|
| 15 |
+
|
| 16 |
+
from config import cache_config
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class CacheEntry:
    """Represents a single cache entry with metadata"""
    key: str
    value: Any
    # time.time() value recorded when the value was stored or last overwritten
    timestamp: float
    # Number of successful reads of this entry
    access_count: int = 0
    # time.time() value of the latest read/write; None means "not given" and
    # is replaced by ``timestamp`` in __post_init__
    last_accessed: Optional[float] = None

    def __post_init__(self):
        """Default ``last_accessed`` to the creation timestamp."""
        if self.last_accessed is None:
            self.last_accessed = self.timestamp

    def to_dict(self) -> dict:
        """Convert to dictionary"""
        return asdict(self)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CacheManager:
    """
    Lock-guarded LRU cache manager for caption results

    Features:
    - Automatic expiration based on TTL
    - LRU eviction when max size reached
    - Operations serialized through a re-entrant lock
    - Access statistics
    """

    def __init__(
        self,
        max_size: int = cache_config.MAX_CACHE_SIZE,
        ttl_seconds: int = cache_config.CACHE_TTL_SECONDS
    ):
        """
        Initialize cache manager

        Args:
            max_size: Maximum number of cached items
            ttl_seconds: Time to live for cache entries
        """
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds

        # Ordered oldest-use first: get()/set() move a key to the end, so the
        # first key is always the least recently used one.
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()

        # Guards _cache and _stats; re-entrant because get_info calls get_stats
        self._lock = threading.RLock()

        # Statistics
        self._stats = {
            "hits": 0,
            "misses": 0,
            "evictions": 0,
            "expirations": 0,
            "total_sets": 0
        }

    def get(self, key: str) -> Optional[Any]:
        """
        Retrieve value from cache

        Args:
            key: Cache key

        Returns:
            Optional[Any]: Cached value or None if not found/expired
        """
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                self._stats["misses"] += 1
                return None

            # An expired entry is dropped and counted as both an expiration
            # and a miss.
            if self._is_expired(entry):
                self._remove_entry(key)
                self._stats["expirations"] += 1
                self._stats["misses"] += 1
                return None

            # Update access statistics
            entry.access_count += 1
            entry.last_accessed = time.time()

            # Move to end (most recently used)
            self._cache.move_to_end(key)

            self._stats["hits"] += 1
            return entry.value

    def set(self, key: str, value: Any) -> bool:
        """
        Store value in cache

        Overwriting an existing key refreshes its timestamp (and therefore
        its TTL) and marks it most recently used.

        Args:
            key: Cache key
            value: Value to cache

        Returns:
            bool: True if successfully cached
        """
        with self._lock:
            current_time = time.time()

            if key in self._cache:
                entry = self._cache[key]
                entry.value = value
                entry.timestamp = current_time
                entry.last_accessed = current_time
                self._cache.move_to_end(key)
            else:
                # Make room before inserting a new key
                if len(self._cache) >= self.max_size:
                    self._evict_oldest()

                self._cache[key] = CacheEntry(
                    key=key,
                    value=value,
                    timestamp=current_time,
                    last_accessed=current_time
                )

            self._stats["total_sets"] += 1
            return True

    def delete(self, key: str) -> bool:
        """
        Remove entry from cache

        Args:
            key: Cache key

        Returns:
            bool: True if entry was deleted
        """
        with self._lock:
            if key not in self._cache:
                return False
            self._remove_entry(key)
            return True

    def clear(self) -> None:
        """Clear all cache entries (statistics are kept)"""
        with self._lock:
            self._cache.clear()

    def _is_expired(self, entry: CacheEntry, now: Optional[float] = None) -> bool:
        """
        Check if cache entry has expired

        Args:
            entry: Entry to test
            now: Reference time; defaults to time.time(). Passing it lets a
                caller test many entries against one instant.
        """
        if now is None:
            now = time.time()
        return (now - entry.timestamp) > self.ttl_seconds

    def _remove_entry(self, key: str) -> None:
        """Remove entry without stats update"""
        if key in self._cache:
            del self._cache[key]

    def _evict_oldest(self) -> None:
        """Evict least recently used entry"""
        if self._cache:
            # OrderedDict: first item is least recently used
            oldest_key = next(iter(self._cache))
            del self._cache[oldest_key]
            self._stats["evictions"] += 1

    def cleanup_expired(self) -> int:
        """
        Remove all expired entries

        Returns:
            int: Number of entries removed
        """
        with self._lock:
            current_time = time.time()
            # Collect first: the dict must not change size while iterating
            expired_keys = [
                key for key, entry in self._cache.items()
                if self._is_expired(entry, current_time)
            ]

            for key in expired_keys:
                del self._cache[key]

            self._stats["expirations"] += len(expired_keys)
            return len(expired_keys)

    def get_stats(self) -> dict:
        """
        Get cache statistics

        Returns:
            dict: Cache statistics including hit rate (percentage, 2 decimals)
        """
        with self._lock:
            total_requests = self._stats["hits"] + self._stats["misses"]
            hit_rate = (
                (self._stats["hits"] / total_requests * 100)
                if total_requests > 0 else 0
            )

            return {
                **self._stats,
                "size": len(self._cache),
                "max_size": self.max_size,
                "hit_rate": round(hit_rate, 2),
                "total_requests": total_requests
            }

    def get_info(self) -> dict:
        """
        Get detailed cache information

        Returns:
            dict: Detailed cache state: "stats", "config", and "entries"
            describing at most the first 10 entries in LRU order
        """
        with self._lock:
            now = time.time()
            entries_info = []
            # Only the first 10 entries are reported, so stop there instead
            # of stringifying every cached value.
            for index, (key, entry) in enumerate(self._cache.items()):
                if index >= 10:
                    break
                entries_info.append({
                    "key": key[:50] + "..." if len(key) > 50 else key,
                    "age_seconds": round(now - entry.timestamp, 2),
                    "access_count": entry.access_count,
                    # Rough size: length of the value's string form
                    "size_estimate": len(str(entry.value))
                })

            return {
                "stats": self.get_stats(),
                "entries": entries_info,
                "config": {
                    "max_size": self.max_size,
                    "ttl_seconds": self.ttl_seconds
                }
            }
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class CaptionCache:
    """
    Caption-specific front end over CacheManager

    Captions are stored under a key built from the image hash, the model
    name and the style. When ``enabled`` is false, lookups return None and
    stores return False without touching the underlying cache.
    """

    def __init__(self):
        """Build the backing CacheManager and read the enable flag from cache_config"""
        self.cache = CacheManager(
            max_size=cache_config.MAX_CACHE_SIZE,
            ttl_seconds=cache_config.CACHE_TTL_SECONDS
        )
        self.enabled = cache_config.ENABLE_CAPTION_CACHE

    def get_caption(
        self,
        image_hash: str,
        model_name: str,
        style: str
    ) -> Optional[str]:
        """
        Look up a previously stored caption

        Args:
            image_hash: Hash of the image
            model_name: Name of the caption model
            style: Style applied

        Returns:
            Optional[str]: The cached caption, or None when caching is
            disabled or the backing cache has no value for the key
        """
        if self.enabled:
            lookup_key = self._generate_key(image_hash, model_name, style)
            return self.cache.get(lookup_key)
        return None

    def set_caption(
        self,
        image_hash: str,
        model_name: str,
        style: str,
        caption: str
    ) -> bool:
        """
        Store a generated caption

        Args:
            image_hash: Hash of the image
            model_name: Name of the caption model
            style: Style applied
            caption: Generated caption

        Returns:
            bool: False when caching is disabled, otherwise the result of
            the backing cache's set()
        """
        if self.enabled:
            store_key = self._generate_key(image_hash, model_name, style)
            return self.cache.set(store_key, caption)
        return False

    def _generate_key(self, image_hash: str, model_name: str, style: str) -> str:
        """Join the three components into a single colon-separated key"""
        cache_key = f"{image_hash}:{model_name}:{style}"
        return cache_key

    def get_stats(self) -> dict:
        """Return the backing cache's statistics"""
        backing = self.cache
        return backing.get_stats()

    def clear(self) -> None:
        """Drop every cached caption"""
        backing = self.cache
        backing.clear()

    def cleanup(self) -> int:
        """Remove expired entries and return how many were removed"""
        backing = self.cache
        return backing.cleanup_expired()
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# Singleton instances (created lazily by the accessors below)
_cache_manager = None
_caption_cache = None
# Serializes first-time creation so concurrent callers cannot each build
# their own instance.
_singleton_lock = threading.Lock()


def get_cache_manager() -> CacheManager:
    """Get singleton CacheManager instance, creating it on first use"""
    global _cache_manager
    if _cache_manager is None:
        with _singleton_lock:
            # Re-check under the lock: another thread may have won the race
            if _cache_manager is None:
                _cache_manager = CacheManager()
    return _cache_manager


def get_caption_cache() -> CaptionCache:
    """Get singleton CaptionCache instance, creating it on first use"""
    global _caption_cache
    if _caption_cache is None:
        with _singleton_lock:
            # Re-check under the lock: another thread may have won the race
            if _caption_cache is None:
                _caption_cache = CaptionCache()
    return _caption_cache
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
if __name__ == "__main__":
    # Manual smoke test for CacheManager: set/get, LRU eviction, TTL expiry.
    divider = "=" * 60
    print(divider)
    print("CACHE MANAGER - TEST MODE")
    print(divider)

    # Tiny cache with a short TTL so eviction and expiry are easy to trigger
    cache = CacheManager(max_size=3, ttl_seconds=5)

    print("\n1. Testing SET operations:")
    for item_key, item_value in (("key1", "value1"), ("key2", "value2"), ("key3", "value3")):
        cache.set(item_key, item_value)
    print(f" Added 3 items")
    print(f" Cache size: {len(cache._cache)}")

    print("\n2. Testing GET operations:")
    result = cache.get("key1")
    print(f" Get 'key1': {result}")
    print(f" Stats: {cache.get_stats()}")

    print("\n3. Testing LRU eviction:")
    # key1 was just read, so key2 is now the least recently used entry
    cache.set("key4", "value4")  # Should evict key2
    print(f" Added 'key4'")
    print(f" Cache size: {len(cache._cache)}")
    print(f" Keys in cache: {list(cache._cache.keys())}")

    print("\n4. Testing TTL expiration:")
    print(f" Waiting 6 seconds for expiration...")
    time.sleep(6)
    expired = cache.cleanup_expired()
    print(f" Expired entries: {expired}")
    print(f" Cache size: {len(cache._cache)}")

    print("\n5. Final stats:")
    for key, value in cache.get_stats().items():
        print(f" {key}: {value}")

    print("\n" + divider)
    print("✓ Cache manager tests complete")
    print(divider)
|
src/utils/image_processor.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image Processing Module
|
| 3 |
+
|
| 4 |
+
Handles image validation, preprocessing, and optimization for caption generation.
|
| 5 |
+
Ensures images meet model requirements while maintaining quality.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import io
|
| 9 |
+
import hashlib
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Tuple, Union
|
| 12 |
+
from PIL import Image, ImageOps
|
| 13 |
+
|
| 14 |
+
from config import image_config
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ImageProcessingError(Exception):
    """Error type used by this module to report image processing failures"""
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ImageProcessor:
|
| 23 |
+
"""
|
| 24 |
+
Enterprise-grade image processor for caption generation pipeline
|
| 25 |
+
|
| 26 |
+
Responsibilities:
|
| 27 |
+
- Validate image format and size
|
| 28 |
+
- Resize and optimize images
|
| 29 |
+
- Generate cache keys
|
| 30 |
+
- Handle edge cases and errors gracefully
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self):
    """Copy the limits and settings from ``image_config`` onto the instance"""
    cfg = image_config
    self.max_size = cfg.MAX_FILE_SIZE_BYTES      # size limit in bytes
    self.max_dimension = cfg.MAX_DIMENSION
    self.min_dimension = cfg.MIN_DIMENSION
    self.allowed_formats = cfg.ALLOWED_FORMATS
    self.quality = cfg.RESIZE_QUALITY
|
| 40 |
+
|
| 41 |
+
def validate_image(self, image: Union[str, Path, Image.Image, bytes]) -> Tuple[bool, str]:
    """
    Validate image meets all requirements

    Checks, in order: the input type, the format (when the image reports
    one), the dimensions, the encoded size (for paths and raw bytes), and
    finally that the pixel data can be decoded. Any exception raised along
    the way is reported as a validation failure rather than propagated.

    Args:
        image: Image path, PIL Image, or bytes

    Returns:
        Tuple[bool, str]: (is_valid, error_message); error_message is ""
        when the image is valid
    """
    # Image object opened by this call (path/bytes input). It is closed in
    # the finally block; a caller-supplied PIL image is never closed here.
    opened_here = None
    try:
        # Load image if path or bytes provided
        if isinstance(image, (str, Path)):
            img = opened_here = Image.open(image)
        elif isinstance(image, bytes):
            img = opened_here = Image.open(io.BytesIO(image))
        elif isinstance(image, Image.Image):
            img = image
        else:
            return False, f"Unsupported image type: {type(image)}"

        # Check format (handle None format from Gradio)
        # When Gradio passes PIL images with type="pil", format can be None
        has_format = getattr(img, 'format', None) is not None
        if has_format:
            allowed = {fmt.upper() for fmt in self.allowed_formats}
            if img.format.upper() not in allowed:
                return False, f"Unsupported format: {img.format}. Allowed: {', '.join(self.allowed_formats)}"
        else:
            # Format is None - likely from Gradio's PIL conversion
            print(f"DEBUG: Image format is None (from Gradio), skipping format check")

        # Check dimensions (read from the header, before any pixel decode)
        width, height = img.size
        if width < self.min_dimension or height < self.min_dimension:
            return False, f"Image too small. Minimum: {self.min_dimension}x{self.min_dimension}px"

        if width > 10000 or height > 10000:
            return False, "Image dimensions too large (max: 10000x10000px)"

        # Check encoded size: on disk for paths, in memory for raw bytes.
        # An already-constructed PIL image has no encoded size to check.
        if isinstance(image, bytes):
            encoded_size = len(image)
        elif isinstance(image, (str, Path)):
            encoded_size = Path(image).stat().st_size
        else:
            encoded_size = None

        if encoded_size is not None and encoded_size > self.max_size:
            max_mb = self.max_size / (1024 * 1024)
            actual_mb = encoded_size / (1024 * 1024)
            return False, f"File too large: {actual_mb:.1f}MB (max: {max_mb}MB)"

        # Verify integrity by decoding the pixel data (skip if format is
        # None). The previous copy().verify() relied on the decode that
        # copy() performs: verify() on an in-memory copy checks nothing.
        # load() does the same decode without duplicating the pixel buffer.
        if has_format:
            img.load()

        return True, ""

    except Exception as e:
        return False, f"Image validation failed: {str(e)}"

    finally:
        # Release the file handle/buffer on every return path
        if opened_here is not None:
            opened_here.close()
|
| 98 |
+
|
| 99 |
+
def preprocess_image(
    self,
    image: Union[str, Path, Image.Image, bytes]
) -> Tuple[Image.Image, dict]:
    """
    Preprocess image for model input

    Steps, in order: validate via ``self.validate_image``, load a detached
    copy of the pixels, apply the EXIF orientation, convert to RGB
    (RGBA is flattened onto a white background), and downsize through
    ``self._resize_image`` when the longest side exceeds
    ``self.max_dimension``.

    Args:
        image: Image path, PIL Image, or bytes

    Returns:
        Tuple[Image.Image, dict]: (processed_image, metadata). Metadata keys:
        original_size, original_format, original_mode, processed_size,
        processed_mode, was_resized, was_converted.

    Raises:
        ImageProcessingError: If validation or preprocessing fails
    """
    try:
        print(f"DEBUG: Preprocessing image of type: {type(image)}")

        # Validate first
        is_valid, error_msg = self.validate_image(image)
        if not is_valid:
            print(f"DEBUG: Validation failed: {error_msg}")
            raise ImageProcessingError(error_msg)

        # Load image. Files/bytes are opened in a context manager and copied so
        # the underlying handle is released instead of leaking; copy() keeps
        # the info dict (and therefore the EXIF block) but drops `format`,
        # which is why the format is captured from the source object.
        if isinstance(image, (str, Path)):
            with Image.open(image) as opened:
                original_format = opened.format
                img = opened.copy()
        elif isinstance(image, bytes):
            with Image.open(io.BytesIO(image)) as opened:
                original_format = opened.format
                img = opened.copy()
        elif isinstance(image, Image.Image):
            original_format = getattr(image, 'format', 'Unknown')
            img = image.copy()  # Don't modify original
        else:
            raise ImageProcessingError(f"Unsupported image type: {type(image)}")

        # Store original metadata (size/mode as stored, before any transform)
        original_size = img.size
        original_mode = img.mode

        print(f"DEBUG: Original format: {original_format}, mode: {original_mode}, size: {original_size}")

        # Auto-orient based on EXIF data. This must run BEFORE the RGB
        # conversion: the RGBA branch below builds a brand-new image that
        # carries no EXIF, so orienting afterwards would silently do nothing.
        img = ImageOps.exif_transpose(img)

        # Convert to RGB if needed (handles RGBA, grayscale, etc.)
        if img.mode != "RGB":
            if img.mode == "RGBA":
                # Create white background for transparent images
                background = Image.new("RGB", img.size, (255, 255, 255))
                background.paste(img, mask=img.split()[-1])  # Use alpha channel as mask
                img = background
            else:
                img = img.convert("RGB")

        # Resize if needed. Tracked explicitly because comparing sizes would
        # misreport a 90-degree EXIF rotation as a resize.
        was_resized = max(img.size) > self.max_dimension
        if was_resized:
            img = self._resize_image(img)

        # Generate metadata
        metadata = {
            "original_size": original_size,
            "original_format": original_format,
            "original_mode": original_mode,
            "processed_size": img.size,
            "processed_mode": img.mode,
            "was_resized": was_resized,
            "was_converted": original_mode != img.mode
        }

        print(f"DEBUG: Preprocessing complete. Final size: {img.size}, mode: {img.mode}")

        return img, metadata

    except ImageProcessingError:
        raise
    except Exception as e:
        print(f"DEBUG: Exception during preprocessing: {str(e)}")
        # Chain the cause so the original traceback is not lost
        raise ImageProcessingError(f"Preprocessing failed: {str(e)}") from e
|
| 178 |
+
|
| 179 |
+
def _resize_image(self, img: Image.Image) -> Image.Image:
    """
    Resize image so it fits ``self.max_dimension``

    When ``image_config.MAINTAIN_ASPECT_RATIO`` is truthy the longer side
    becomes ``self.max_dimension`` and the shorter side is scaled by the
    same factor; otherwise the image is forced to a
    ``max_dimension`` x ``max_dimension`` square. The method always
    resizes; deciding whether a resize is needed is left to the caller.

    Args:
        img: PIL Image

    Returns:
        Image.Image: Resized image (a new object)
    """
    width, height = img.size
    target = self.max_dimension

    if image_config.MAINTAIN_ASPECT_RATIO:
        # Calculate new dimensions maintaining aspect ratio.
        # int() truncates, so very elongated images could yield a 0-pixel
        # side, which Image.resize rejects; clamp the short side to 1px.
        if width > height:
            new_width = target
            new_height = max(1, int((height / width) * target))
        else:
            new_height = target
            new_width = max(1, int((width / height) * target))
    else:
        new_width = target
        new_height = target

    # Use high-quality resampling
    return img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
| 207 |
+
|
| 208 |
+
def generate_image_hash(
    self,
    image: Union[str, Path, Image.Image, bytes],
    algorithm: str = "md5"
) -> str:
    """
    Compute a hex digest identifying an image (used as a cache key)

    Paths are hashed over the raw file contents, bytes are hashed as given,
    and PIL Images are hashed over their PNG serialization.

    Args:
        image: Image path, PIL Image, or bytes
        algorithm: Hash algorithm (md5, sha256)

    Returns:
        str: Hexadecimal hash string

    Raises:
        ImageProcessingError: On unsupported input type, unsupported
            algorithm, or any I/O / encoding failure
    """
    try:
        # Step 1: reduce every accepted input kind to a byte string
        if isinstance(image, (str, Path)):
            with open(image, "rb") as handle:
                payload = handle.read()
        elif isinstance(image, bytes):
            payload = image
        elif isinstance(image, Image.Image):
            png_sink = io.BytesIO()
            image.save(png_sink, format="PNG")
            payload = png_sink.getvalue()
        else:
            raise ValueError(f"Unsupported type for hashing: {type(image)}")

        # Step 2: pick the digest constructor by name
        supported = (("md5", hashlib.md5), ("sha256", hashlib.sha256))
        for name, factory in supported:
            if algorithm == name:
                return factory(payload).hexdigest()

        raise ValueError(f"Unsupported hash algorithm: {algorithm}")

    except Exception as e:
        raise ImageProcessingError(f"Hash generation failed: {str(e)}")
|
| 247 |
+
|
| 248 |
+
def image_to_bytes(self, img: Image.Image, format: str = "PNG") -> bytes:
    """
    Serialize a PIL Image into an in-memory byte string

    The image is saved with ``self.quality`` passed as the ``quality``
    option of ``img.save``.

    Args:
        img: PIL Image
        format: Output format (PNG, JPEG)

    Returns:
        bytes: Image bytes
    """
    with io.BytesIO() as sink:
        img.save(sink, format=format, quality=self.quality)
        encoded = sink.getvalue()
    return encoded
|
| 262 |
+
|
| 263 |
+
def get_image_info(self, image: Union[str, Path, Image.Image]) -> dict:
    """
    Get detailed image information

    For a path, ``file_size`` is the on-disk size; for a PIL Image it is the
    length of the bytes produced by ``self.image_to_bytes`` with its default
    format.

    Args:
        image: Image path or PIL Image

    Returns:
        dict: Image information with keys format, mode, size, width, height,
        file_size, file_size_mb, aspect_ratio, megapixels

    Raises:
        ImageProcessingError: If the input type is unsupported or the image
            cannot be read
    """
    try:
        if isinstance(image, (str, Path)):
            # Context manager closes the file handle that Image.open keeps
            # alive; only header fields are read, so nothing is needed from
            # the image once the block exits.
            with Image.open(image) as img:
                img_format, img_mode, img_size = img.format, img.mode, img.size
            file_size = Path(image).stat().st_size
        elif isinstance(image, Image.Image):
            img_format, img_mode, img_size = image.format, image.mode, image.size
            file_size = len(self.image_to_bytes(image))
        else:
            raise ValueError(f"Unsupported type: {type(image)}")

        width, height = img_size[0], img_size[1]
        return {
            "format": img_format,
            "mode": img_mode,
            "size": img_size,
            "width": width,
            "height": height,
            "file_size": file_size,
            "file_size_mb": file_size / (1024 * 1024),
            "aspect_ratio": width / height,
            "megapixels": (width * height) / 1_000_000
        }
    except Exception as e:
        # Chain the cause so the original traceback is not lost
        raise ImageProcessingError(f"Failed to get image info: {str(e)}") from e
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# ============================================================================
|
| 299 |
+
# SINGLETON INSTANCE AND CONVENIENCE FUNCTIONS
|
| 300 |
+
# ============================================================================
|
| 301 |
+
|
| 302 |
+
# Module-wide cache for the shared processor; populated on first request
_image_processor = None


def get_image_processor() -> ImageProcessor:
    """Return the shared ImageProcessor, constructing it on the first call"""
    global _image_processor
    if _image_processor is not None:
        return _image_processor
    _image_processor = ImageProcessor()
    return _image_processor
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# Convenience wrapper functions for backward compatibility
|
| 314 |
+
def validate_image(image: Union[str, Path, Image.Image, bytes]) -> Tuple[bool, str]:
    """
    Module-level shortcut that validates an image with the shared processor

    Args:
        image: Image path, PIL Image, or bytes

    Returns:
        Tuple[bool, str]: (is_valid, error_message)
    """
    processor = get_image_processor()
    return processor.validate_image(image)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def preprocess_image(
    image: Union[str, Path, Image.Image, bytes]
) -> Tuple[Image.Image, dict]:
    """
    Module-level shortcut that preprocesses an image with the shared processor

    Args:
        image: Image path, PIL Image, or bytes

    Returns:
        Tuple[Image.Image, dict]: (processed_image, metadata)
    """
    processor = get_image_processor()
    return processor.preprocess_image(image)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def generate_image_hash(
    image: Union[str, Path, Image.Image, bytes],
    algorithm: str = "md5"
) -> str:
    """
    Module-level shortcut that hashes an image with the shared processor

    Args:
        image: Image path, PIL Image, or bytes
        algorithm: Hash algorithm (md5, sha256)

    Returns:
        str: Hexadecimal hash string
    """
    processor = get_image_processor()
    return processor.generate_image_hash(image, algorithm)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
if __name__ == "__main__":
    # Manual smoke check: build the shared processor and print its settings
    rule = "=" * 60

    print(rule)
    print("IMAGE PROCESSOR - TEST MODE")
    print(rule)

    processor = get_image_processor()
    max_size_mb = processor.max_size / (1024*1024)
    formats_text = ', '.join(processor.allowed_formats)

    print("✓ ImageProcessor initialized")
    print(f" - Max file size: {max_size_mb:.1f}MB")
    print(f" - Max dimension: {processor.max_dimension}px")
    print(f" - Allowed formats: {formats_text}")
    print(f" - Quality: {processor.quality}")
    print(rule)
    print("Ready for testing with actual images")
    print(rule)
|