Spaces:
Configuration error
Configuration error
Upload 23 files
Browse files- .gitattributes +1 -0
- .gitignore +176 -0
- Dockerfile +46 -0
- FT.yaml +66 -0
- README.md +343 -10
- app.py +180 -0
- blip2_vicuna_instruct.py +747 -0
- emo/all.py +138 -0
- emo/cap-anno.py +37 -0
- emo/caption.py +26 -0
- emo/desktop.ini +2 -0
- emo/gpt4_conversation.py +87 -0
- emo/gpt4_reasoning.py +89 -0
- emo/prompt/conversation.txt +16 -0
- emo/prompt/description.txt +8 -0
- emo/prompt/reasoning.txt +49 -0
- emo/test.json +0 -0
- emo/train.json +3 -0
- emo/val.json +0 -0
- requirements.txt +39 -0
- requirements_emo.txt +166 -0
- requirements_lavis.txt +158 -0
- static/css/style.css +278 -0
- templates/index.html +253 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
emo/train.json filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
# Usually these files are written by a python script from a template
|
| 32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
.coverage
|
| 45 |
+
.coverage.*
|
| 46 |
+
.cache
|
| 47 |
+
nosetests.xml
|
| 48 |
+
coverage.xml
|
| 49 |
+
*.cover
|
| 50 |
+
*.py,cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
target/
|
| 76 |
+
|
| 77 |
+
# Jupyter Notebook
|
| 78 |
+
.ipynb_checkpoints
|
| 79 |
+
|
| 80 |
+
# IPython
|
| 81 |
+
profile_default/
|
| 82 |
+
ipython_config.py
|
| 83 |
+
|
| 84 |
+
# pyenv
|
| 85 |
+
.python-version
|
| 86 |
+
|
| 87 |
+
# pipenv
|
| 88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 91 |
+
# install all needed dependencies.
|
| 92 |
+
#Pipfile.lock
|
| 93 |
+
|
| 94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 95 |
+
__pypackages__/
|
| 96 |
+
|
| 97 |
+
# Celery stuff
|
| 98 |
+
celerybeat-schedule
|
| 99 |
+
celerybeat.pid
|
| 100 |
+
|
| 101 |
+
# SageMath parsed files
|
| 102 |
+
*.sage.py
|
| 103 |
+
|
| 104 |
+
# Environments
|
| 105 |
+
.env
|
| 106 |
+
.venv
|
| 107 |
+
env/
|
| 108 |
+
venv/
|
| 109 |
+
ENV/
|
| 110 |
+
env.bak/
|
| 111 |
+
venv.bak/
|
| 112 |
+
|
| 113 |
+
# Spyder project settings
|
| 114 |
+
.spyderproject
|
| 115 |
+
.spyproject
|
| 116 |
+
|
| 117 |
+
# Rope project settings
|
| 118 |
+
.ropeproject
|
| 119 |
+
|
| 120 |
+
# mkdocs documentation
|
| 121 |
+
/site
|
| 122 |
+
|
| 123 |
+
# mypy
|
| 124 |
+
.mypy_cache/
|
| 125 |
+
.dmypy.json
|
| 126 |
+
dmypy.json
|
| 127 |
+
|
| 128 |
+
# Pyre type checker
|
| 129 |
+
.pyre/
|
| 130 |
+
|
| 131 |
+
# Model files and checkpoints
|
| 132 |
+
*.pth
|
| 133 |
+
*.pt
|
| 134 |
+
*.bin
|
| 135 |
+
*.safetensors
|
| 136 |
+
checkpoints/
|
| 137 |
+
models/
|
| 138 |
+
|
| 139 |
+
# Data files
|
| 140 |
+
*.csv
|
| 141 |
+
*.json
|
| 142 |
+
*.txt
|
| 143 |
+
!requirements*.txt
|
| 144 |
+
!README.txt
|
| 145 |
+
|
| 146 |
+
# Uploaded files
|
| 147 |
+
uploads/
|
| 148 |
+
temp/
|
| 149 |
+
|
| 150 |
+
# IDE files
|
| 151 |
+
.vscode/
|
| 152 |
+
.idea/
|
| 153 |
+
*.swp
|
| 154 |
+
*.swo
|
| 155 |
+
*~
|
| 156 |
+
|
| 157 |
+
# OS files
|
| 158 |
+
.DS_Store
|
| 159 |
+
.DS_Store?
|
| 160 |
+
._*
|
| 161 |
+
.Spotlight-V100
|
| 162 |
+
.Trashes
|
| 163 |
+
ehthumbs.db
|
| 164 |
+
Thumbs.db
|
| 165 |
+
|
| 166 |
+
# Logs
|
| 167 |
+
logs/
|
| 168 |
+
*.log
|
| 169 |
+
|
| 170 |
+
# Temporary files
|
| 171 |
+
tmp/
|
| 172 |
+
temp/
|
| 173 |
+
*.tmp
|
| 174 |
+
|
| 175 |
+
# Hugging Face cache
|
| 176 |
+
.cache/
|
Dockerfile
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.9-slim
|
| 2 |
+
|
| 3 |
+
# Set working directory
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Install system dependencies
|
| 7 |
+
RUN apt-get update && apt-get install -y \
|
| 8 |
+
git \
|
| 9 |
+
wget \
|
| 10 |
+
curl \
|
| 11 |
+
build-essential \
|
| 12 |
+
libgl1-mesa-glx \
|
| 13 |
+
libglib2.0-0 \
|
| 14 |
+
libsm6 \
|
| 15 |
+
libxext6 \
|
| 16 |
+
libxrender-dev \
|
| 17 |
+
libgomp1 \
|
| 18 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 19 |
+
|
| 20 |
+
# Copy requirements first to leverage Docker cache
|
| 21 |
+
COPY requirements.txt .
|
| 22 |
+
|
| 23 |
+
# Install Python dependencies
|
| 24 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 25 |
+
pip install --no-cache-dir -r requirements.txt
|
| 26 |
+
|
| 27 |
+
# Copy application code
|
| 28 |
+
COPY . .
|
| 29 |
+
|
| 30 |
+
# Create necessary directories
|
| 31 |
+
RUN mkdir -p static/css templates
|
| 32 |
+
|
| 33 |
+
# Set environment variables for Hugging Face Spaces
|
| 34 |
+
ENV PYTHONPATH=/app
|
| 35 |
+
ENV FLASK_APP=app.py
|
| 36 |
+
ENV FLASK_ENV=production
|
| 37 |
+
|
| 38 |
+
# Expose port for Hugging Face Spaces
|
| 39 |
+
EXPOSE 7860
|
| 40 |
+
|
| 41 |
+
# Health check
|
| 42 |
+
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
|
| 43 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 44 |
+
|
| 45 |
+
# Command to run the application
|
| 46 |
+
CMD ["python", "app.py"]
|
FT.yaml
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022, salesforce.com, inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 5 |
+
|
| 6 |
+
model:
|
| 7 |
+
arch: blip2_opt
|
| 8 |
+
model_type: caption_coco_opt2.7b
|
| 9 |
+
load_finetuned: False
|
| 10 |
+
use_grad_checkpoint: True
|
| 11 |
+
#freeze_vit: False
|
| 12 |
+
freeze_vit: True
|
| 13 |
+
|
| 14 |
+
datasets:
|
| 15 |
+
coco_vqa: # name of the dataset builder
|
| 16 |
+
vis_processor:
|
| 17 |
+
train:
|
| 18 |
+
name: "blip2_image_train"
|
| 19 |
+
image_size: 224
|
| 20 |
+
eval:
|
| 21 |
+
name: "blip_image_eval"
|
| 22 |
+
image_size: 224
|
| 23 |
+
text_processor:
|
| 24 |
+
train:
|
| 25 |
+
name: "blip_caption"
|
| 26 |
+
prompt: " "
|
| 27 |
+
eval:
|
| 28 |
+
name: "blip_caption"
|
| 29 |
+
# build_info:
|
| 30 |
+
# images:
|
| 31 |
+
# storage: '/export/share/datasets/vision/coco/images/'
|
| 32 |
+
|
| 33 |
+
run:
|
| 34 |
+
task: vqa
|
| 35 |
+
# optimizer
|
| 36 |
+
lr_sched: "linear_warmup_cosine_lr"
|
| 37 |
+
init_lr: 1e-5 #8e-6 #1e-5
|
| 38 |
+
min_lr: 1e-8
|
| 39 |
+
warmup_lr: 1e-8
|
| 40 |
+
warmup_steps: 30000
|
| 41 |
+
weight_decay: 0.005 #0.05 #0.00005
|
| 42 |
+
max_epoch: 4
|
| 43 |
+
batch_size_train: 1
|
| 44 |
+
batch_size_eval: 1
|
| 45 |
+
num_workers: 4
|
| 46 |
+
accum_grad_iters: 1
|
| 47 |
+
|
| 48 |
+
max_len: 1000
|
| 49 |
+
min_len: 1 #8
|
| 50 |
+
num_beams: 1 #5 #3
|
| 51 |
+
|
| 52 |
+
seed: 42
|
| 53 |
+
output_dir: "output"
|
| 54 |
+
|
| 55 |
+
amp: True
|
| 56 |
+
resume_ckpt_path: null
|
| 57 |
+
|
| 58 |
+
evaluate: False
|
| 59 |
+
train_splits: ["train"]
|
| 60 |
+
valid_splits: ["val"]
|
| 61 |
+
test_splits: ["test"]
|
| 62 |
+
|
| 63 |
+
device: "cuda"
|
| 64 |
+
world_size: 1
|
| 65 |
+
dist_url: "env://"
|
| 66 |
+
distributed: True
|
README.md
CHANGED
|
@@ -1,10 +1,343 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: EmoVIT
|
| 3 |
-
emoji: 😻
|
| 4 |
-
colorFrom: gray
|
| 5 |
-
colorTo: purple
|
| 6 |
-
sdk: docker
|
| 7 |
-
pinned: false
|
| 8 |
-
---
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: EmoVIT
|
| 3 |
+
emoji: 😻
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# EmoVIT - Emotion Detection with BLIP2-Vicuna
|
| 11 |
+
|
| 12 |
+
🚀 **AI-Powered Emotion Detection Web Application**
|
| 13 |
+
|
| 14 |
+
EmoVIT is a sophisticated emotion detection application that leverages the power of BLIP2-Vicuna model to analyze emotions in images through natural language understanding.
|
| 15 |
+
|
| 16 |
+
## 🌟 Features
|
| 17 |
+
|
| 18 |
+
- **🖼️ Image Upload**: Easy drag-and-drop or click-to-upload interface
|
| 19 |
+
- **🧠 AI Analysis**: Advanced emotion detection using BLIP2-Vicuna model
|
| 20 |
+
- **💬 Custom Prompts**: Personalize your analysis with custom text prompts
|
| 21 |
+
- **🎨 Beautiful UI**: Modern, responsive design with smooth animations
|
| 22 |
+
- **⚡ Real-time Processing**: Fast inference with optimized model loading
|
| 23 |
+
- **📱 Mobile Friendly**: Works seamlessly on all devices
|
| 24 |
+
|
| 25 |
+
## 🛠️ Technology Stack
|
| 26 |
+
|
| 27 |
+
- **Backend**: Flask (Python web framework)
|
| 28 |
+
- **AI Model**: BLIP2-Vicuna (Vision-Language model)
|
| 29 |
+
- **Frontend**: HTML5, CSS3, JavaScript, Bootstrap 5
|
| 30 |
+
- **Deployment**: Docker + Hugging Face Spaces
|
| 31 |
+
|
| 32 |
+
## 🚀 Quick Start
|
| 33 |
+
|
| 34 |
+
### Local Development
|
| 35 |
+
|
| 36 |
+
1. **Clone the repository**
|
| 37 |
+
```bash
|
| 38 |
+
git clone <your-repo-url>
|
| 39 |
+
cd EmoVIT
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
2. **Install dependencies**
|
| 43 |
+
```bash
|
| 44 |
+
pip install -r requirements.txt
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
3. **Run the application**
|
| 48 |
+
```bash
|
| 49 |
+
python app.py
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
4. **Open in browser**
|
| 53 |
+
Navigate to `http://localhost:7860`
|
| 54 |
+
|
| 55 |
+
### Docker Deployment
|
| 56 |
+
|
| 57 |
+
1. **Build the Docker image**
|
| 58 |
+
```bash
|
| 59 |
+
docker build -t emovit .
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
2. **Run the container**
|
| 63 |
+
```bash
|
| 64 |
+
docker run -p 7860:7860 emovit
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
## 🌐 Hugging Face Spaces Deployment
|
| 68 |
+
|
| 69 |
+
This application is configured for seamless deployment on Hugging Face Spaces:
|
| 70 |
+
|
| 71 |
+
1. **Create a new Space** on [Hugging Face Spaces](https://huggingface.co/spaces)
|
| 72 |
+
2. **Select Docker** as the SDK
|
| 73 |
+
3. **Upload your files** to the Space repository
|
| 74 |
+
4. **The app will automatically deploy** using the provided Dockerfile
|
| 75 |
+
|
| 76 |
+
### Required Files for HF Spaces:
|
| 77 |
+
- `app.py` - Main Flask application
|
| 78 |
+
- `Dockerfile` - Container configuration
|
| 79 |
+
- `requirements.txt` - Python dependencies
|
| 80 |
+
- `templates/` - HTML templates
|
| 81 |
+
- `static/` - CSS and static assets
|
| 82 |
+
- `blip2_vicuna_instruct.py` - Model implementation
|
| 83 |
+
|
| 84 |
+
## 📁 Project Structure
|
| 85 |
+
|
| 86 |
+
```
|
| 87 |
+
EmoVIT/
|
| 88 |
+
├── app.py # Main Flask application
|
| 89 |
+
├── blip2_vicuna_instruct.py # BLIP2-Vicuna model implementation
|
| 90 |
+
├── requirements.txt # Python dependencies
|
| 91 |
+
├── Dockerfile # Docker configuration
|
| 92 |
+
├── README.md # This file
|
| 93 |
+
├── templates/
|
| 94 |
+
│ └── index.html # Main HTML template
|
| 95 |
+
├── static/
|
| 96 |
+
│ └── css/
|
| 97 |
+
│ └── style.css # Custom CSS styles
|
| 98 |
+
└── emo/ # Emotion datasets and utilities
|
| 99 |
+
├── train.json
|
| 100 |
+
├── val.json
|
| 101 |
+
└── test.json
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
## 🎯 How It Works
|
| 105 |
+
|
| 106 |
+
1. **Upload Image**: Users upload an image through the web interface
|
| 107 |
+
2. **Enter Prompt**: Optionally customize the analysis prompt
|
| 108 |
+
3. **AI Processing**: The BLIP2-Vicuna model processes the image and prompt
|
| 109 |
+
4. **Results Display**: Emotion analysis results are displayed with the original image
|
| 110 |
+
|
| 111 |
+
## 🔧 Configuration
|
| 112 |
+
|
| 113 |
+
### Model Configuration
|
| 114 |
+
The model can be configured in `app.py`:
|
| 115 |
+
|
| 116 |
+
```python
|
| 117 |
+
model_config = {
|
| 118 |
+
"vit_model": "eva_clip_g",
|
| 119 |
+
"img_size": 224,
|
| 120 |
+
"num_query_token": 32,
|
| 121 |
+
"llm_model": "vicuna-7b-v1.1",
|
| 122 |
+
"max_txt_len": 128,
|
| 123 |
+
"max_output_txt_len": 256,
|
| 124 |
+
# ... other configurations
|
| 125 |
+
}
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### Environment Variables
|
| 129 |
+
- `PORT`: Application port (default: 7860)
|
| 130 |
+
- `FLASK_ENV`: Flask environment (production/development)
|
| 131 |
+
|
| 132 |
+
## 🤖 Model Details
|
| 133 |
+
|
| 134 |
+
**BLIP2-Vicuna** combines:
|
| 135 |
+
- **Vision Encoder**: EVA-CLIP for image understanding
|
| 136 |
+
- **Q-Former**: Querying transformer for cross-modal alignment
|
| 137 |
+
- **Language Model**: Vicuna (LLaMA-based) for text generation
|
| 138 |
+
|
| 139 |
+
This architecture enables sophisticated vision-language understanding for emotion detection tasks.
|
| 140 |
+
|
| 141 |
+
## 📊 Performance & Optimization
|
| 142 |
+
|
| 143 |
+
- **GPU Support**: Automatic CUDA detection and utilization
|
| 144 |
+
- **Memory Efficient**: Optimized model loading and inference
|
| 145 |
+
- **Caching**: Smart caching for improved response times
|
| 146 |
+
- **Error Handling**: Robust error handling and user feedback
|
| 147 |
+
|
| 148 |
+
## 🎨 UI/UX Features
|
| 149 |
+
|
| 150 |
+
- **Responsive Design**: Works on desktop, tablet, and mobile
|
| 151 |
+
- **Modern Aesthetics**: Clean, professional interface
|
| 152 |
+
- **Smooth Animations**: Engaging user interactions
|
| 153 |
+
- **Loading States**: Clear feedback during processing
|
| 154 |
+
- **Error Handling**: User-friendly error messages
|
| 155 |
+
|
| 156 |
+
## 🔒 Security Features
|
| 157 |
+
|
| 158 |
+
- **File Size Limits**: 16MB maximum upload size
|
| 159 |
+
- **File Type Validation**: Only image files accepted
|
| 160 |
+
- **Input Sanitization**: Secure handling of user inputs
|
| 161 |
+
- **CORS Protection**: Appropriate cross-origin policies
|
| 162 |
+
|
| 163 |
+
## 🚀 Deployment Options
|
| 164 |
+
|
| 165 |
+
### 1. Hugging Face Spaces (Recommended)
|
| 166 |
+
- Zero-configuration deployment
|
| 167 |
+
- Automatic scaling
|
| 168 |
+
- Free tier available
|
| 169 |
+
- Built-in GPU support
|
| 170 |
+
|
| 171 |
+
### 2. Docker
|
| 172 |
+
- Consistent environments
|
| 173 |
+
- Easy scaling
|
| 174 |
+
- Platform independent
|
| 175 |
+
|
| 176 |
+
### 3. Local Development
|
| 177 |
+
- Quick testing
|
| 178 |
+
- Development workflow
|
| 179 |
+
- Custom configurations
|
| 180 |
+
|
| 181 |
+
## 🛠️ Development
|
| 182 |
+
|
| 183 |
+
### Adding New Features
|
| 184 |
+
1. Update `app.py` for backend changes
|
| 185 |
+
2. Modify `templates/index.html` for UI changes
|
| 186 |
+
3. Update `static/css/style.css` for styling
|
| 187 |
+
4. Test locally before deployment
|
| 188 |
+
|
| 189 |
+
### Model Updates
|
| 190 |
+
1. Update `blip2_vicuna_instruct.py`
|
| 191 |
+
2. Adjust configuration in `app.py`
|
| 192 |
+
3. Update requirements if needed
|
| 193 |
+
|
| 194 |
+
## 📄 License
|
| 195 |
+
|
| 196 |
+
This project is open-source and available under the MIT License.
|
| 197 |
+
|
| 198 |
+
## 🤝 Contributing
|
| 199 |
+
|
| 200 |
+
Contributions are welcome! Please feel free to submit a Pull Request.
|
| 201 |
+
|
| 202 |
+
## 📞 Support
|
| 203 |
+
|
| 204 |
+
For questions or support, please open an issue in the repository.
|
| 205 |
+
|
| 206 |
+
---
|
| 207 |
+
|
| 208 |
+
**Built with ❤️ using BLIP2-Vicuna and modern web technologies**
|
| 209 |
+
Official code for the paper **"EmoVIT: Revolutionizing Emotion Insights with Visual Instruction Tuning"** | CVPR 2024
|
| 210 |
+
|
| 211 |
+
## 🔄 Update Log – 2025/04/07
|
| 212 |
+
|
| 213 |
+
### 📄 Dataset Update
|
| 214 |
+
|
| 215 |
+
- The originally provided `train.json` was incomplete.
|
| 216 |
+
✅ The latest version now contains the full dataset and can be downloaded here:
|
| 217 |
+
[📎 Download `train.json`](https://drive.google.com/file/d/1OV3X7BJyEDYXTGaDbu7E8rGgGzIlnwVq/view?usp=drive_link)
|
| 218 |
+
|
| 219 |
+
### ✅ Bug Fixes & Configuration Updates
|
| 220 |
+
|
| 221 |
+
- Fixed incorrect version parameters previously used during inference.
|
| 222 |
+
- Updated several related parameter files—please **replace** the original files with the latest versions found in the root directory.
|
| 223 |
+
|
| 224 |
+
### 🛠️ Configuration Changes
|
| 225 |
+
|
| 226 |
+
- `FT.yaml` and `blip2_vicuna_instruct` have been modified to incorporate the correct parameter settings.
|
| 227 |
+
- 📁 **Note:** `blip2_vicuna_instruct` should be placed in:
|
| 228 |
+
`LAVIS/lavis/models/blip2_models/`
|
| 229 |
+
|
| 230 |
+
### 💾 Trained Weights
|
| 231 |
+
|
| 232 |
+
- Weights trained using the **corrected parameters** are now available.
|
| 233 |
+
[📥 Download Trained Weights](https://drive.google.com/file/d/1zaYOSlt3mLVMdiNfAKdJcwvVc-4LHfdr/view?usp=drive_link)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
## 🧠 Emotion Reasoning Support
|
| 237 |
+
|
| 238 |
+
To enable **emotion reasoning output** from the model, format the input prompt as:
|
| 239 |
+
|
| 240 |
+
`Predicted emotion: [emotion]. Reason: [explanation].`
|
| 241 |
+
|
| 242 |
+
🔍 This usage is described in **Section 4.4.1 (Affective Reasoning)** of our paper.
|
| 243 |
+
|
| 244 |
+
---
|
| 245 |
+
|
| 246 |
+
## Setting up the environment
|
| 247 |
+
|
| 248 |
+
```bash
|
| 249 |
+
git clone https://github.com/aimmemotion/EmoVIT.git
|
| 250 |
+
conda create --name emovit python=3.8
|
| 251 |
+
conda activate emovit
|
| 252 |
+
|
| 253 |
+
cd Emovit
|
| 254 |
+
pip install -r requirements_lavis.txt
|
| 255 |
+
```
|
| 256 |
+
## Install the corresponding version of PyTorch
|
| 257 |
+
|
| 258 |
+
```bash
|
| 259 |
+
#Using CUDA 11.8 as an example
|
| 260 |
+
pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu118
|
| 261 |
+
```
|
| 262 |
+
|
| 263 |
+
## Install LAVIS
|
| 264 |
+
|
| 265 |
+
```bash
|
| 266 |
+
pip install salesforce-lavis
|
| 267 |
+
# If not work, please proceed as follows.
|
| 268 |
+
cd ..
|
| 269 |
+
git clone https://github.com/salesforce/LAVIS.git
|
| 270 |
+
cd LAVIS
|
| 271 |
+
pip install -e . # Please remove 'open3d' from the 'requirements.txt' file to avoid version conflicts.
|
| 272 |
+
cd ../
|
| 273 |
+
# Cut the 'lavis' folder and paste it into the 'lib' folder.
|
| 274 |
+
```
|
| 275 |
+
|
| 276 |
+
## Dataset Preparation
|
| 277 |
+
|
| 278 |
+
Download EmoSet from
|
| 279 |
+
https://vcc.tech/EmoSet
|
| 280 |
+
|
| 281 |
+
Extract the downloaded EmoSet files
|
| 282 |
+
(annotation, image, info.json, test.json, train.json, val.json)
|
| 283 |
+
and place them into the emo folder.
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
## Model Preparation
|
| 287 |
+
|
| 288 |
+
Download lavis_with_weight.zip https://drive.google.com/file/d/1vZa7C6rxxsq51VQ73ESGQ0S8zEI2dnq_/view?usp=drive_link
|
| 289 |
+
(If you prefer to train it yourself, you can download lavis_without_weight.zip instead https://drive.google.com/file/d/1Re_lzyrQehuL1SjP4GmgPCMPf5jHg3hs/view?usp=drive_link)
|
| 290 |
+
Extract the zip file and place it in the emovit folder.
|
| 291 |
+
|
| 292 |
+
Download all files from this Hugging Face page
|
| 293 |
+
https://huggingface.co/lmsys/vicuna-7b-v1.1/tree/main
|
| 294 |
+
Place the downloaded files into ./Emovit/LAVIS/lavis/weight/vicuna-7b-2/
|
| 295 |
+
|
| 296 |
+
## Emotion Instruction Data Generation
|
| 297 |
+
|
| 298 |
+
1. Run `python ./emo/caption.py` to obtain image captions. Select the 'path' based on the class to be processed.
|
| 299 |
+
2. Run `python ./emo/cap-anno.py` to write the attributes and captions of the image into a file. Select the 'path' based on the class to be processed.
|
| 300 |
+
3. Run `python ./emo/gpt4_reasoning.py` or `python ./emo/gpt4_conversation.py` to instruct GPT-4 to generate questions using the above file as input data.
|
| 301 |
+
- Remember to change the key.
|
| 302 |
+
- If you wish to adjust the prompt, you can go to the 'prompt' folder.
|
| 303 |
+
4. Run `python ./emo/all.py` to integrate the results of reasoning, conversation, and classification.
|
| 304 |
+
|
| 305 |
+
Following these steps, you can create instructions. If you want to skip this step, you can use the instructions we created using EmoSet. (However, image data must still be downloaded from EmoSet's official website.)
|
| 306 |
+
|
| 307 |
+
- Conversation: [Download](https://drive.google.com/file/d/1E8UEH09y0CiAT4Hg7rm975AR3JCjEHeM/view?usp=drive_link)
|
| 308 |
+
- Reasoning: [Download](https://drive.google.com/file/d/1MTNHFzasCb0F921P0itaH-x8vN2OvxEu/view?usp=drive_link)
|
| 309 |
+
|
| 310 |
+
The generation method of categorical data does not need to rely on GPT for creation; it can be directly produced (you can observe the prompt in `all.py`).
|
| 311 |
+
|
| 312 |
+
#### Training
|
| 313 |
+
|
| 314 |
+
```bash
|
| 315 |
+
cd LAVIS
|
| 316 |
+
python train.py --cfg-path FT.yaml
|
| 317 |
+
```
|
| 318 |
+
|
| 319 |
+
### Parameter Settings
|
| 320 |
+
|
| 321 |
+
- `LAVIS/FT.yaml`: Setting of hyperparameters
|
| 322 |
+
- `LAVIS/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml`: Select the location of LLM weight
|
| 323 |
+
- `LAVIS/lavis/configs/datasets/coco/defaults_vqa.yaml`: Select the location of your data
|
| 324 |
+
LAVIS/lavis/runners/runner_base.py (Change the name of the weight file to be saved.)
|
| 325 |
+
|
| 326 |
+
## Inference EmoVIT
|
| 327 |
+
If you haven't trained your own weights yet, you can use the `model_weights1.pth` provided in the `LAVIS` folder.
|
| 328 |
+
```bash
|
| 329 |
+
python ./LAVIS/test.py
|
| 330 |
+
```
|
| 331 |
+
|
| 332 |
+
## Citation
|
| 333 |
+
|
| 334 |
+
If you found this paper is helpful, please consider cite our paper:
|
| 335 |
+
|
| 336 |
+
```bibtex
|
| 337 |
+
@inproceedings{Xie2024EmoVIT,
|
| 338 |
+
title={EmoVIT: Revolutionizing Emotion Insights with Visual Instruction Tuning},
|
| 339 |
+
author={Hongxia Xie and Chu-Jun Peng and Yu-Wen Tseng and Hung-Jen Chen and Chan-Feng Hsu and Hong-Han Shuai and Wen-Huang Cheng},
|
| 340 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
| 341 |
+
year={2024}
|
| 342 |
+
}
|
| 343 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import io
|
| 3 |
+
import torch
|
| 4 |
+
from flask import Flask, render_template, request, jsonify, url_for
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import base64
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
# Import model từ file hiện tại
|
| 11 |
+
from blip2_vicuna_instruct import Blip2VicunaInstruct
|
| 12 |
+
|
| 13 |
+
app = Flask(__name__)
|
| 14 |
+
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
|
| 15 |
+
|
| 16 |
+
# Global variables cho model
|
| 17 |
+
model = None
|
| 18 |
+
device = None
|
| 19 |
+
|
| 20 |
+
def load_model():
|
| 21 |
+
"""Load BLIP2 Vicuna model"""
|
| 22 |
+
global model, device
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 26 |
+
logging.info(f"Using device: {device}")
|
| 27 |
+
|
| 28 |
+
# Cấu hình model - có thể cần điều chỉnh theo config thực tế
|
| 29 |
+
model_config = {
|
| 30 |
+
"vit_model": "eva_clip_g",
|
| 31 |
+
"img_size": 224,
|
| 32 |
+
"drop_path_rate": 0,
|
| 33 |
+
"use_grad_checkpoint": False,
|
| 34 |
+
"vit_precision": "fp16",
|
| 35 |
+
"freeze_vit": True,
|
| 36 |
+
"num_query_token": 32,
|
| 37 |
+
"llm_model": "vicuna-7b-v1.1", # Có thể cần thay đổi path
|
| 38 |
+
"prompt": "",
|
| 39 |
+
"max_txt_len": 128,
|
| 40 |
+
"max_output_txt_len": 256,
|
| 41 |
+
"apply_lemmatizer": False,
|
| 42 |
+
"qformer_text_input": True,
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
# Khởi tạo model
|
| 46 |
+
model = Blip2VicunaInstruct(**model_config)
|
| 47 |
+
model.to(device)
|
| 48 |
+
model.eval()
|
| 49 |
+
|
| 50 |
+
logging.info("Model loaded successfully!")
|
| 51 |
+
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logging.error(f"Error loading model: {str(e)}")
|
| 54 |
+
model = None
|
| 55 |
+
|
| 56 |
+
def preprocess_image(image):
|
| 57 |
+
"""Preprocess image for model"""
|
| 58 |
+
try:
|
| 59 |
+
# Resize và normalize image
|
| 60 |
+
if image.mode != 'RGB':
|
| 61 |
+
image = image.convert('RGB')
|
| 62 |
+
|
| 63 |
+
# Resize to model input size
|
| 64 |
+
image = image.resize((224, 224))
|
| 65 |
+
|
| 66 |
+
# Convert to tensor
|
| 67 |
+
import torchvision.transforms as transforms
|
| 68 |
+
transform = transforms.Compose([
|
| 69 |
+
transforms.ToTensor(),
|
| 70 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 71 |
+
std=[0.229, 0.224, 0.225])
|
| 72 |
+
])
|
| 73 |
+
|
| 74 |
+
image_tensor = transform(image).unsqueeze(0)
|
| 75 |
+
return image_tensor
|
| 76 |
+
|
| 77 |
+
except Exception as e:
|
| 78 |
+
logging.error(f"Error preprocessing image: {str(e)}")
|
| 79 |
+
return None
|
| 80 |
+
|
| 81 |
+
def predict_emotion(image_tensor, prompt="What emotion is shown in this image?"):
|
| 82 |
+
"""Predict emotion từ image"""
|
| 83 |
+
global model, device
|
| 84 |
+
|
| 85 |
+
if model is None:
|
| 86 |
+
return "Model not loaded"
|
| 87 |
+
|
| 88 |
+
try:
|
| 89 |
+
with torch.no_grad():
|
| 90 |
+
# Prepare samples
|
| 91 |
+
samples = {
|
| 92 |
+
"image": image_tensor.to(device),
|
| 93 |
+
"text_input": [prompt]
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
# Generate prediction
|
| 97 |
+
result = model.generate(
|
| 98 |
+
samples,
|
| 99 |
+
use_nucleus_sampling=False,
|
| 100 |
+
num_beams=3,
|
| 101 |
+
max_length=50,
|
| 102 |
+
min_length=1,
|
| 103 |
+
temperature=0.1,
|
| 104 |
+
repetition_penalty=1.1
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
return result[0] if result else "Unable to predict emotion"
|
| 108 |
+
|
| 109 |
+
except Exception as e:
|
| 110 |
+
logging.error(f"Error predicting emotion: {str(e)}")
|
| 111 |
+
return f"Error: {str(e)}"
|
| 112 |
+
|
| 113 |
+
@app.route('/')
|
| 114 |
+
def index():
|
| 115 |
+
"""Home page"""
|
| 116 |
+
return render_template('index.html')
|
| 117 |
+
|
| 118 |
+
@app.route('/predict', methods=['POST'])
|
| 119 |
+
def predict():
|
| 120 |
+
"""Handle image upload and prediction"""
|
| 121 |
+
try:
|
| 122 |
+
if 'image' not in request.files:
|
| 123 |
+
return jsonify({'error': 'No image file provided'}), 400
|
| 124 |
+
|
| 125 |
+
file = request.files['image']
|
| 126 |
+
if file.filename == '':
|
| 127 |
+
return jsonify({'error': 'No image selected'}), 400
|
| 128 |
+
|
| 129 |
+
# Đọc và xử lý image
|
| 130 |
+
image = Image.open(io.BytesIO(file.read()))
|
| 131 |
+
|
| 132 |
+
# Preprocess image
|
| 133 |
+
image_tensor = preprocess_image(image)
|
| 134 |
+
if image_tensor is None:
|
| 135 |
+
return jsonify({'error': 'Failed to process image'}), 400
|
| 136 |
+
|
| 137 |
+
# Get custom prompt if provided
|
| 138 |
+
custom_prompt = request.form.get('prompt', 'What emotion is shown in this image?')
|
| 139 |
+
|
| 140 |
+
# Predict emotion
|
| 141 |
+
emotion_result = predict_emotion(image_tensor, custom_prompt)
|
| 142 |
+
|
| 143 |
+
# Convert image to base64 for display
|
| 144 |
+
buffered = io.BytesIO()
|
| 145 |
+
image.save(buffered, format="PNG")
|
| 146 |
+
img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 147 |
+
|
| 148 |
+
return jsonify({
|
| 149 |
+
'success': True,
|
| 150 |
+
'emotion': emotion_result,
|
| 151 |
+
'image': img_str,
|
| 152 |
+
'prompt': custom_prompt
|
| 153 |
+
})
|
| 154 |
+
|
| 155 |
+
except Exception as e:
|
| 156 |
+
logging.error(f"Error in prediction: {str(e)}")
|
| 157 |
+
return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
|
| 158 |
+
|
| 159 |
+
@app.route('/health')
|
| 160 |
+
def health():
|
| 161 |
+
"""Health check endpoint"""
|
| 162 |
+
return jsonify({
|
| 163 |
+
'status': 'healthy',
|
| 164 |
+
'model_loaded': model is not None,
|
| 165 |
+
'device': str(device) if device else 'unknown'
|
| 166 |
+
})
|
| 167 |
+
|
| 168 |
+
if __name__ == '__main__':
|
| 169 |
+
# Setup logging
|
| 170 |
+
logging.basicConfig(level=logging.INFO)
|
| 171 |
+
|
| 172 |
+
# Load model
|
| 173 |
+
logging.info("Loading model...")
|
| 174 |
+
load_model()
|
| 175 |
+
|
| 176 |
+
# Determine port for Hugging Face Spaces
|
| 177 |
+
port = int(os.environ.get("PORT", 7860))
|
| 178 |
+
|
| 179 |
+
# Run app
|
| 180 |
+
app.run(host="0.0.0.0", port=port, debug=False)
|
blip2_vicuna_instruct.py
ADDED
|
@@ -0,0 +1,747 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Requires Transformer 4.28 and above, implementation may change according the Llama implementation
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
import string
|
| 6 |
+
from packaging import version
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.cuda.amp import autocast as autocast
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
import transformers
|
| 13 |
+
|
| 14 |
+
from lavis.common.registry import registry
|
| 15 |
+
from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
|
| 16 |
+
|
| 17 |
+
@registry.register_model("blip2_vicuna_instruct")
|
| 18 |
+
class Blip2VicunaInstruct(Blip2Base):
|
| 19 |
+
"""
|
| 20 |
+
BLIP2 Vicuna model.
|
| 21 |
+
Supported model types:
|
| 22 |
+
- vicuna7b
|
| 23 |
+
- vicuna13b
|
| 24 |
+
Usage:
|
| 25 |
+
>>> from lavis.models import load_model
|
| 26 |
+
>>> model = load_model("blip2_vicuna_instruct", "vicuna7b")
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
PRETRAINED_MODEL_CONFIG_DICT = {
|
| 30 |
+
"vicuna7b": "configs/models/blip2/blip2_instruct_vicuna7b.yaml",
|
| 31 |
+
"vicuna13b": "configs/models/blip2/blip2_instruct_vicuna13b.yaml",
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
vit_model="eva_clip_g",
|
| 37 |
+
img_size=224,
|
| 38 |
+
drop_path_rate=0,
|
| 39 |
+
use_grad_checkpoint=False,
|
| 40 |
+
vit_precision="fp16",
|
| 41 |
+
freeze_vit=True,
|
| 42 |
+
num_query_token=32,
|
| 43 |
+
llm_model="",
|
| 44 |
+
prompt="",
|
| 45 |
+
max_txt_len=128,
|
| 46 |
+
max_output_txt_len=256,
|
| 47 |
+
apply_lemmatizer=False,
|
| 48 |
+
qformer_text_input=True,
|
| 49 |
+
):
|
| 50 |
+
super().__init__()
|
| 51 |
+
transformers_version = version.parse(transformers.__version__)
|
| 52 |
+
assert transformers_version >= version.parse("4.28"), "BLIP-2 Vicuna requires transformers>=4.28"
|
| 53 |
+
from transformers import LlamaTokenizer
|
| 54 |
+
from lavis.models.blip2_models.modeling_llama import LlamaForCausalLM
|
| 55 |
+
|
| 56 |
+
self.tokenizer = self.init_tokenizer(truncation_side="left")
|
| 57 |
+
|
| 58 |
+
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
| 59 |
+
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
| 60 |
+
)
|
| 61 |
+
if freeze_vit:
|
| 62 |
+
for name, param in self.visual_encoder.named_parameters():
|
| 63 |
+
param.requires_grad = False
|
| 64 |
+
self.visual_encoder = self.visual_encoder.eval()
|
| 65 |
+
self.visual_encoder.train = disabled_train
|
| 66 |
+
logging.info("freeze vision encoder")
|
| 67 |
+
|
| 68 |
+
self.Qformer, self.query_tokens = self.init_Qformer(
|
| 69 |
+
num_query_token, self.visual_encoder.num_features
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
if not qformer_text_input:
|
| 73 |
+
self.Qformer.bert.embeddings.word_embeddings = None
|
| 74 |
+
self.Qformer.bert.embeddings.position_embeddings = None
|
| 75 |
+
for layer in self.Qformer.bert.encoder.layer:
|
| 76 |
+
layer.output = None
|
| 77 |
+
layer.intermediate = None
|
| 78 |
+
else:
|
| 79 |
+
self.Qformer.resize_token_embeddings(len(self.tokenizer))
|
| 80 |
+
self.Qformer.cls = None
|
| 81 |
+
|
| 82 |
+
self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_model, use_fast=False, truncation_side="left")
|
| 83 |
+
self.llm_model = LlamaForCausalLM.from_pretrained(
|
| 84 |
+
llm_model, torch_dtype=torch.float16
|
| 85 |
+
)
|
| 86 |
+
self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
| 87 |
+
self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
|
| 88 |
+
self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
|
| 89 |
+
self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
|
| 90 |
+
# self.llm_tokenizer.pad_token = self.llm_tokenizer.unk_token
|
| 91 |
+
|
| 92 |
+
self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
|
| 93 |
+
|
| 94 |
+
# self.eos_token_id = self.llm_tokenizer(
|
| 95 |
+
# self.llm_tokenizer.eos_token, add_special_tokens=False
|
| 96 |
+
# ).input_ids[0]
|
| 97 |
+
|
| 98 |
+
for name, param in self.llm_model.named_parameters():
|
| 99 |
+
param.requires_grad = False
|
| 100 |
+
|
| 101 |
+
self.llm_proj = nn.Linear(
|
| 102 |
+
self.Qformer.config.hidden_size, self.llm_model.config.hidden_size
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
self.max_txt_len = max_txt_len
|
| 106 |
+
self.max_output_txt_len = max_output_txt_len
|
| 107 |
+
self.prompt = prompt
|
| 108 |
+
prompt_tokens = self.llm_tokenizer(self.prompt, return_tensors="pt")
|
| 109 |
+
self.prompt_length = prompt_tokens.attention_mask.sum(1)
|
| 110 |
+
|
| 111 |
+
self._lemmatizer = None
|
| 112 |
+
|
| 113 |
+
self.qformer_text_input = qformer_text_input
|
| 114 |
+
|
| 115 |
+
def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts):
|
| 116 |
+
input_part_targets_len = []
|
| 117 |
+
llm_tokens = {"input_ids": [], "attention_mask": []}
|
| 118 |
+
for i in range(input_ids.size(0)):
|
| 119 |
+
this_input_ones = input_atts[i].sum()
|
| 120 |
+
input_part_targets_len.append(this_input_ones)
|
| 121 |
+
llm_tokens['input_ids'].append(
|
| 122 |
+
torch.cat([
|
| 123 |
+
input_ids[i][:this_input_ones],
|
| 124 |
+
output_ids[i][1:],
|
| 125 |
+
input_ids[i][this_input_ones:]
|
| 126 |
+
])
|
| 127 |
+
)
|
| 128 |
+
llm_tokens['attention_mask'].append(
|
| 129 |
+
torch.cat([
|
| 130 |
+
input_atts[i][:this_input_ones],
|
| 131 |
+
output_atts[i][1:],
|
| 132 |
+
input_atts[i][this_input_ones:]
|
| 133 |
+
])
|
| 134 |
+
)
|
| 135 |
+
llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
|
| 136 |
+
llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
|
| 137 |
+
return llm_tokens, input_part_targets_len
|
| 138 |
+
|
| 139 |
+
def forward(self, samples):
|
| 140 |
+
#print('-----------------')
|
| 141 |
+
#print(samples["text_input"])
|
| 142 |
+
#print(samples["text_output"])
|
| 143 |
+
#print('-----------------')
|
| 144 |
+
#print(samples)
|
| 145 |
+
#print(samples["text_input"])
|
| 146 |
+
#print(samples["answer"])
|
| 147 |
+
#sss
|
| 148 |
+
|
| 149 |
+
image = samples["image"]
|
| 150 |
+
with self.maybe_autocast():
|
| 151 |
+
image_embeds = self.ln_vision(self.visual_encoder(image))
|
| 152 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
| 153 |
+
|
| 154 |
+
bs = image.size(0)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
for i in range(len(samples["text_input"])):
|
| 158 |
+
samples["answer"][i] = samples["answer"][i].replace('Complex','')
|
| 159 |
+
#samples["text_input"][i] = samples["text_input"][i].replace(' predict emotion',': Predict emotion:')
|
| 160 |
+
#samples["text_input"][i] = samples["text_input"][i].lstrip().capitalize()
|
| 161 |
+
#print(samples["text_input"])
|
| 162 |
+
#print(samples["answer"])
|
| 163 |
+
#print('----------------')
|
| 164 |
+
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
| 165 |
+
if self.qformer_text_input:
|
| 166 |
+
text_Qformer = self.tokenizer(
|
| 167 |
+
samples["text_input"],
|
| 168 |
+
padding='longest',
|
| 169 |
+
truncation=True,
|
| 170 |
+
max_length=self.max_txt_len,
|
| 171 |
+
return_tensors="pt",
|
| 172 |
+
).to(image.device)
|
| 173 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
|
| 174 |
+
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],dim=1)
|
| 175 |
+
|
| 176 |
+
query_output = self.Qformer.bert(
|
| 177 |
+
text_Qformer.input_ids,
|
| 178 |
+
attention_mask=Qformer_atts,
|
| 179 |
+
query_embeds=query_tokens,
|
| 180 |
+
encoder_hidden_states=image_embeds,
|
| 181 |
+
encoder_attention_mask=image_atts,
|
| 182 |
+
return_dict=True,
|
| 183 |
+
)
|
| 184 |
+
else:
|
| 185 |
+
query_output = self.Qformer.bert(
|
| 186 |
+
query_embeds=query_tokens,
|
| 187 |
+
encoder_hidden_states=image_embeds,
|
| 188 |
+
encoder_attention_mask=image_atts,
|
| 189 |
+
return_dict=True,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
| 193 |
+
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
|
| 194 |
+
|
| 195 |
+
self.llm_tokenizer.padding_side = "right"
|
| 196 |
+
self.llm_tokenizer.truncation_side = 'left'
|
| 197 |
+
text_input_tokens = self.llm_tokenizer(
|
| 198 |
+
samples['text_input'],
|
| 199 |
+
return_tensors="pt",
|
| 200 |
+
padding="longest",
|
| 201 |
+
truncation=True,
|
| 202 |
+
max_length=self.max_txt_len,
|
| 203 |
+
).to(image.device)
|
| 204 |
+
|
| 205 |
+
self.llm_tokenizer.truncation_side = 'right'
|
| 206 |
+
text_output_tokens = self.llm_tokenizer(
|
| 207 |
+
[t + self.llm_tokenizer.eos_token for t in samples['answer']],
|
| 208 |
+
return_tensors="pt",
|
| 209 |
+
padding="longest",
|
| 210 |
+
truncation=True,
|
| 211 |
+
max_length=self.max_output_txt_len,
|
| 212 |
+
).to(image.device)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
llm_tokens, input_part_targets_len = self.concat_text_input_output(
|
| 216 |
+
text_input_tokens.input_ids,
|
| 217 |
+
text_input_tokens.attention_mask,
|
| 218 |
+
text_output_tokens.input_ids,
|
| 219 |
+
text_output_tokens.attention_mask,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
# do not apply loss to the padding
|
| 223 |
+
targets = llm_tokens['input_ids'].masked_fill(
|
| 224 |
+
llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# do not apply loss to the text input (i.e., instruction)
|
| 228 |
+
for i, l in enumerate(input_part_targets_len):
|
| 229 |
+
targets[i][:l] = -100
|
| 230 |
+
|
| 231 |
+
# do not apply loss to the query tokens
|
| 232 |
+
empty_targets = (
|
| 233 |
+
torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
|
| 234 |
+
)
|
| 235 |
+
targets = torch.cat([empty_targets, targets], dim=1)
|
| 236 |
+
|
| 237 |
+
inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids'])
|
| 238 |
+
inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
|
| 239 |
+
attention_mask = torch.cat([atts_llm, llm_tokens['attention_mask']], dim=1)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
with self.maybe_autocast():
|
| 243 |
+
outputs = self.llm_model(
|
| 244 |
+
inputs_embeds=inputs_embeds,
|
| 245 |
+
attention_mask=attention_mask,
|
| 246 |
+
return_dict=True,
|
| 247 |
+
labels=targets,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
loss = outputs.loss
|
| 252 |
+
|
| 253 |
+
return {"loss": loss}
|
| 254 |
+
|
| 255 |
+
@torch.no_grad()
|
| 256 |
+
def generate(
|
| 257 |
+
self,
|
| 258 |
+
samples,
|
| 259 |
+
use_nucleus_sampling=False,
|
| 260 |
+
num_beams=1, #5
|
| 261 |
+
max_length=256, #256
|
| 262 |
+
min_length=1,
|
| 263 |
+
top_p=0.9,
|
| 264 |
+
repetition_penalty=5.0, #1.5
|
| 265 |
+
length_penalty=1,
|
| 266 |
+
num_captions=1,
|
| 267 |
+
temperature=0, #1
|
| 268 |
+
):
|
| 269 |
+
self.llm_tokenizer.padding_side = "left"
|
| 270 |
+
|
| 271 |
+
if "prompt" in samples.keys():
|
| 272 |
+
prompt = samples["prompt"]
|
| 273 |
+
else:
|
| 274 |
+
prompt = self.prompt
|
| 275 |
+
|
| 276 |
+
image = samples["image"]
|
| 277 |
+
|
| 278 |
+
bs = image.size(0)
|
| 279 |
+
|
| 280 |
+
if isinstance(prompt, str):
|
| 281 |
+
prompt = [prompt] * bs
|
| 282 |
+
else:
|
| 283 |
+
assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# For TextCaps
|
| 287 |
+
if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
|
| 288 |
+
prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]
|
| 289 |
+
|
| 290 |
+
query_tokens = self.query_tokens.expand(bs, -1, -1)
|
| 291 |
+
if self.qformer_text_input:
|
| 292 |
+
# remove ocr tokens in q_former (for eval textvqa)
|
| 293 |
+
# qformer_prompt = prompt
|
| 294 |
+
# qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt]
|
| 295 |
+
|
| 296 |
+
text_Qformer = self.tokenizer(
|
| 297 |
+
prompt,
|
| 298 |
+
padding='longest',
|
| 299 |
+
truncation=True,
|
| 300 |
+
max_length=self.max_txt_len,
|
| 301 |
+
return_tensors="pt",
|
| 302 |
+
).to(image.device)
|
| 303 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
|
| 304 |
+
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
|
| 305 |
+
|
| 306 |
+
# For video data
|
| 307 |
+
if image.dim() == 5:
|
| 308 |
+
inputs_llm, atts_llm = [], []
|
| 309 |
+
for j in range(image.size(2)):
|
| 310 |
+
this_frame = image[:,:,j,:,:]
|
| 311 |
+
with self.maybe_autocast():
|
| 312 |
+
frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
|
| 313 |
+
frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
| 314 |
+
|
| 315 |
+
if self.qformer_text_input:
|
| 316 |
+
frame_query_output = self.Qformer.bert(
|
| 317 |
+
text_Qformer.input_ids,
|
| 318 |
+
attention_mask=Qformer_atts,
|
| 319 |
+
query_embeds=query_tokens,
|
| 320 |
+
encoder_hidden_states=frame_embeds,
|
| 321 |
+
encoder_attention_mask=frame_atts,
|
| 322 |
+
return_dict=True,
|
| 323 |
+
)
|
| 324 |
+
else:
|
| 325 |
+
frame_query_output = self.Qformer.bert(
|
| 326 |
+
query_embeds=query_tokens,
|
| 327 |
+
encoder_hidden_states=frame_embeds,
|
| 328 |
+
encoder_attention_mask=frame_atts,
|
| 329 |
+
return_dict=True,
|
| 330 |
+
)
|
| 331 |
+
frame_inputs_llm = self.llm_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
| 332 |
+
frame_atts_llm = torch.ones(frame_inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
|
| 333 |
+
inputs_llm.append(frame_inputs_llm)
|
| 334 |
+
atts_llm.append(frame_atts_llm)
|
| 335 |
+
inputs_llm = torch.cat(inputs_llm, dim=1)
|
| 336 |
+
atts_llm = torch.cat(atts_llm, dim=1)
|
| 337 |
+
else:
|
| 338 |
+
with self.maybe_autocast():
|
| 339 |
+
image_embeds = self.ln_vision(self.visual_encoder(image))
|
| 340 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
| 341 |
+
|
| 342 |
+
if self.qformer_text_input:
|
| 343 |
+
query_output = self.Qformer.bert(
|
| 344 |
+
text_Qformer.input_ids,
|
| 345 |
+
attention_mask=Qformer_atts,
|
| 346 |
+
query_embeds=query_tokens,
|
| 347 |
+
encoder_hidden_states=image_embeds,
|
| 348 |
+
encoder_attention_mask=image_atts,
|
| 349 |
+
return_dict=True,
|
| 350 |
+
)
|
| 351 |
+
else:
|
| 352 |
+
query_output = self.Qformer.bert(
|
| 353 |
+
query_embeds=query_tokens,
|
| 354 |
+
encoder_hidden_states=image_embeds,
|
| 355 |
+
encoder_attention_mask=image_atts,
|
| 356 |
+
return_dict=True,
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
| 360 |
+
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
|
| 361 |
+
|
| 362 |
+
llm_tokens = self.llm_tokenizer(
|
| 363 |
+
prompt,
|
| 364 |
+
padding="longest",
|
| 365 |
+
return_tensors="pt"
|
| 366 |
+
).to(image.device)
|
| 367 |
+
|
| 368 |
+
with self.maybe_autocast():
|
| 369 |
+
inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
|
| 370 |
+
inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
|
| 371 |
+
attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1)
|
| 372 |
+
|
| 373 |
+
outputs = self.llm_model.generate(
|
| 374 |
+
inputs_embeds=inputs_embeds,
|
| 375 |
+
attention_mask=attention_mask,
|
| 376 |
+
do_sample=use_nucleus_sampling,
|
| 377 |
+
top_p=top_p,
|
| 378 |
+
temperature=temperature,
|
| 379 |
+
num_beams=num_beams,
|
| 380 |
+
max_length=max_length,
|
| 381 |
+
min_length=min_length,
|
| 382 |
+
# eos_token_id=self.eos_token_id,
|
| 383 |
+
repetition_penalty=repetition_penalty,
|
| 384 |
+
length_penalty=length_penalty,
|
| 385 |
+
num_return_sequences=num_captions,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
|
| 389 |
+
output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
| 390 |
+
output_text = [text.strip() for text in output_text]
|
| 391 |
+
|
| 392 |
+
return output_text
|
| 393 |
+
|
| 394 |
+
def predict_answers(
|
| 395 |
+
self,
|
| 396 |
+
samples,
|
| 397 |
+
num_beams=5,
|
| 398 |
+
inference_method="generate",
|
| 399 |
+
max_len=10,
|
| 400 |
+
min_len=1,
|
| 401 |
+
num_ans_candidates=128,
|
| 402 |
+
answer_list=None,
|
| 403 |
+
prompt="",
|
| 404 |
+
length_penalty=0,
|
| 405 |
+
**kwargs
|
| 406 |
+
):
|
| 407 |
+
if isinstance(samples["text_input"], str):
|
| 408 |
+
samples["text_input"] = [samples["text_input"]]
|
| 409 |
+
|
| 410 |
+
if prompt:
|
| 411 |
+
if prompt.count("{}") == 2:
|
| 412 |
+
if 'ocr_tokens' in samples:
|
| 413 |
+
text_input = [
|
| 414 |
+
prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["text_input"][i])
|
| 415 |
+
for i in range(len(samples["text_input"]))]
|
| 416 |
+
elif 'choices' in samples:
|
| 417 |
+
text_input = []
|
| 418 |
+
for i in range(len(samples["text_input"])):
|
| 419 |
+
this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])]
|
| 420 |
+
this_choices = " ".join(this_choices)
|
| 421 |
+
text_input.append(prompt.format(samples["text_input"][i], this_choices))
|
| 422 |
+
else:
|
| 423 |
+
text_input = [prompt.format(question) for question in samples["text_input"]]
|
| 424 |
+
else:
|
| 425 |
+
text_input = samples["text_input"]
|
| 426 |
+
|
| 427 |
+
samples["prompt"] = text_input
|
| 428 |
+
|
| 429 |
+
output_text = self.generate(
|
| 430 |
+
samples,
|
| 431 |
+
num_beams=num_beams,
|
| 432 |
+
max_length=max_len,
|
| 433 |
+
min_length=min_len,
|
| 434 |
+
length_penalty=length_penalty
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
|
| 438 |
+
output_text = self._lemmatize(output_text)
|
| 439 |
+
|
| 440 |
+
return output_text
|
| 441 |
+
|
| 442 |
+
def predict_class(
|
| 443 |
+
self,
|
| 444 |
+
samples,
|
| 445 |
+
candidates,
|
| 446 |
+
n_segments=1,
|
| 447 |
+
):
|
| 448 |
+
self.llm_tokenizer.padding_side = "left"
|
| 449 |
+
|
| 450 |
+
# If candidates is a list of lists, each sample has its candidates, then we need to iterate one by one
|
| 451 |
+
if type(candidates[0]) == list:
|
| 452 |
+
results = []
|
| 453 |
+
|
| 454 |
+
for i in range(samples["image"].size(0)):
|
| 455 |
+
this_sample = {
|
| 456 |
+
"image": samples["image"][i].unsqueeze(0),
|
| 457 |
+
"prompt": samples["prompt"],
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
if "text_input" in samples.keys():
|
| 461 |
+
this_sample["text_input"] = [samples["text_input"][i]]
|
| 462 |
+
|
| 463 |
+
if 'context' in samples.keys():
|
| 464 |
+
this_sample['context'] = [samples["context"][i]]
|
| 465 |
+
|
| 466 |
+
if 'history' in samples.keys():
|
| 467 |
+
this_sample['history'] = [samples["history"][i]]
|
| 468 |
+
|
| 469 |
+
if 'caption' in samples.keys():
|
| 470 |
+
this_sample['caption'] = [samples["caption"][i]]
|
| 471 |
+
|
| 472 |
+
this_result = self._predict_class(this_sample, candidates[i], n_segments)
|
| 473 |
+
results.append(this_result)
|
| 474 |
+
|
| 475 |
+
try:
|
| 476 |
+
results = torch.cat(results, dim=0)
|
| 477 |
+
except:
|
| 478 |
+
results = [res.tolist()[0] for res in results]
|
| 479 |
+
|
| 480 |
+
return results
|
| 481 |
+
|
| 482 |
+
return self._predict_class(samples, candidates, n_segments)
|
| 483 |
+
|
| 484 |
+
def _predict_class(
|
| 485 |
+
self,
|
| 486 |
+
samples,
|
| 487 |
+
candidates,
|
| 488 |
+
n_segments=1,
|
| 489 |
+
):
|
| 490 |
+
image = samples["image"]
|
| 491 |
+
prompt = samples["prompt"]
|
| 492 |
+
|
| 493 |
+
bs = image.size(0)
|
| 494 |
+
|
| 495 |
+
if isinstance(prompt, str):
|
| 496 |
+
prompt = [prompt] * bs
|
| 497 |
+
else:
|
| 498 |
+
assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
|
| 499 |
+
|
| 500 |
+
if "text_input" in samples.keys():
|
| 501 |
+
if type(samples["text_input"][0]) == list:
|
| 502 |
+
prompt = [prompt[i].format(*samples["text_input"][i]) for i in range(len(prompt))]
|
| 503 |
+
else:
|
| 504 |
+
prompt = [prompt[i].format(samples["text_input"][i]) for i in range(len(prompt))]
|
| 505 |
+
|
| 506 |
+
# scienceqa
|
| 507 |
+
if 'context' in samples.keys() and samples['context'] != '':
|
| 508 |
+
prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
|
| 509 |
+
|
| 510 |
+
# visual dialog
|
| 511 |
+
if 'history' in samples.keys() and samples['history'][0] != '':
|
| 512 |
+
prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(len(prompt))]
|
| 513 |
+
|
| 514 |
+
if 'caption' in samples.keys() and samples['caption'][0] != '':
|
| 515 |
+
prompt = [f'This image has the caption "{samples["caption"][i]}". {prompt[i]}' for i in range(len(prompt))]
|
| 516 |
+
|
| 517 |
+
query_tokens = self.query_tokens.expand(bs, -1, -1)
|
| 518 |
+
if self.qformer_text_input:
|
| 519 |
+
text_Qformer = self.tokenizer(
|
| 520 |
+
prompt,
|
| 521 |
+
padding='longest',
|
| 522 |
+
truncation=True,
|
| 523 |
+
max_length=self.max_txt_len,
|
| 524 |
+
return_tensors="pt"
|
| 525 |
+
).to(image.device)
|
| 526 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
|
| 527 |
+
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
|
| 528 |
+
|
| 529 |
+
if image.dim() == 5:
|
| 530 |
+
inputs_llm, atts_llm = [], []
|
| 531 |
+
for j in range(image.size(2)):
|
| 532 |
+
this_frame = image[:,:,j,:,:]
|
| 533 |
+
with self.maybe_autocast():
|
| 534 |
+
frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
|
| 535 |
+
frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
| 536 |
+
|
| 537 |
+
if self.qformer_text_input:
|
| 538 |
+
frame_query_output = self.Qformer.bert(
|
| 539 |
+
text_Qformer.input_ids,
|
| 540 |
+
attention_mask=Qformer_atts,
|
| 541 |
+
query_embeds=query_tokens,
|
| 542 |
+
encoder_hidden_states=frame_embeds,
|
| 543 |
+
encoder_attention_mask=frame_atts,
|
| 544 |
+
return_dict=True,
|
| 545 |
+
)
|
| 546 |
+
else:
|
| 547 |
+
frame_query_output = self.Qformer.bert(
|
| 548 |
+
query_embeds=query_tokens,
|
| 549 |
+
encoder_hidden_states=frame_embeds,
|
| 550 |
+
encoder_attention_mask=frame_atts,
|
| 551 |
+
return_dict=True,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
frame_inputs_llm = self.llm_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
| 555 |
+
frame_atts_llm = torch.ones(frame_inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
|
| 556 |
+
inputs_llm.append(frame_inputs_llm)
|
| 557 |
+
atts_llm.append(frame_atts_llm)
|
| 558 |
+
inputs_llm = torch.cat(inputs_llm, dim=1)
|
| 559 |
+
atts_llm = torch.cat(atts_llm, dim=1)
|
| 560 |
+
else:
|
| 561 |
+
with self.maybe_autocast():
|
| 562 |
+
image_embeds = self.ln_vision(self.visual_encoder(image))
|
| 563 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
| 564 |
+
|
| 565 |
+
if self.qformer_text_input:
|
| 566 |
+
query_output = self.Qformer.bert(
|
| 567 |
+
text_Qformer.input_ids,
|
| 568 |
+
attention_mask=Qformer_atts,
|
| 569 |
+
query_embeds=query_tokens,
|
| 570 |
+
encoder_hidden_states=image_embeds,
|
| 571 |
+
encoder_attention_mask=image_atts,
|
| 572 |
+
return_dict=True,
|
| 573 |
+
)
|
| 574 |
+
else:
|
| 575 |
+
query_output = self.Qformer.bert(
|
| 576 |
+
query_embeds=query_tokens,
|
| 577 |
+
encoder_hidden_states=image_embeds,
|
| 578 |
+
encoder_attention_mask=image_atts,
|
| 579 |
+
return_dict=True,
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
| 583 |
+
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
|
| 584 |
+
|
| 585 |
+
self.llm_tokenizer.padding_side = "right"
|
| 586 |
+
self.llm_tokenizer.truncation_side = 'left'
|
| 587 |
+
text_input_tokens = self.llm_tokenizer(
|
| 588 |
+
prompt,
|
| 589 |
+
return_tensors="pt",
|
| 590 |
+
padding="longest",
|
| 591 |
+
# truncation=True,
|
| 592 |
+
# max_length=self.max_txt_len,
|
| 593 |
+
).to(image.device)
|
| 594 |
+
|
| 595 |
+
empty_targets = torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
|
| 596 |
+
|
| 597 |
+
# self.llm_tokenizer.padding_side = "right"
|
| 598 |
+
self.llm_tokenizer.truncation_side = 'right'
|
| 599 |
+
n_cands = len(candidates)
|
| 600 |
+
with self.maybe_autocast(dtype=torch.bfloat16):
|
| 601 |
+
all_losses = []
|
| 602 |
+
for n in range(n_segments):
|
| 603 |
+
seg_len = n_cands // n_segments
|
| 604 |
+
if n == (n_segments - 1):
|
| 605 |
+
seg_len = n_cands - seg_len * (n_segments - 1)
|
| 606 |
+
|
| 607 |
+
start_i = n * (n_cands // n_segments)
|
| 608 |
+
end_i = start_i + seg_len
|
| 609 |
+
|
| 610 |
+
this_output_tokens = self.llm_tokenizer(
|
| 611 |
+
candidates[start_i:end_i],
|
| 612 |
+
return_tensors="pt",
|
| 613 |
+
padding="longest",
|
| 614 |
+
# truncation=True,
|
| 615 |
+
# max_length=self.max_output_txt_len,
|
| 616 |
+
).to(image.device)
|
| 617 |
+
|
| 618 |
+
this_input_tokens_ids = text_input_tokens.input_ids.repeat_interleave(seg_len, dim=0)
|
| 619 |
+
this_input_tokens_atts = text_input_tokens.attention_mask.repeat_interleave(seg_len, dim=0)
|
| 620 |
+
|
| 621 |
+
this_output_tokens_ids = this_output_tokens.input_ids.repeat(bs, 1)
|
| 622 |
+
this_output_tokens_atts = this_output_tokens.attention_mask.repeat(bs, 1)
|
| 623 |
+
|
| 624 |
+
this_llm_tokens, this_input_targets_len = self.concat_text_input_output(
|
| 625 |
+
this_input_tokens_ids,
|
| 626 |
+
this_input_tokens_atts,
|
| 627 |
+
this_output_tokens_ids,
|
| 628 |
+
this_output_tokens_atts
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
this_llm_input_ids = this_llm_tokens['input_ids']
|
| 632 |
+
this_llm_atts = this_llm_tokens['attention_mask']
|
| 633 |
+
# this_llm_input_ids = torch.cat([this_input_tokens_ids, this_output_tokens_ids], dim=1)
|
| 634 |
+
# this_llm_atts = torch.cat([this_input_tokens_atts, this_output_tokens_atts], dim=1)
|
| 635 |
+
|
| 636 |
+
inputs_embeds = self.llm_model.get_input_embeddings()(this_llm_input_ids)
|
| 637 |
+
inputs_embeds = torch.cat([inputs_llm.repeat_interleave(seg_len, dim=0), inputs_embeds], dim=1)
|
| 638 |
+
attention_mask = torch.cat([atts_llm.repeat_interleave(seg_len, dim=0), this_llm_atts], dim=1)
|
| 639 |
+
|
| 640 |
+
this_targets = this_llm_input_ids.masked_fill(this_llm_input_ids == self.llm_tokenizer.pad_token_id, -100)
|
| 641 |
+
# this_targets[:, :this_input_tokens_ids.size(1)] = -100
|
| 642 |
+
for i, l in enumerate(this_input_targets_len):
|
| 643 |
+
this_targets[i][:l] = -100
|
| 644 |
+
|
| 645 |
+
this_targets = torch.cat([empty_targets.repeat_interleave(seg_len, dim=0), this_targets], dim=1)
|
| 646 |
+
|
| 647 |
+
outputs = self.llm_model(
|
| 648 |
+
inputs_embeds=inputs_embeds,
|
| 649 |
+
attention_mask=attention_mask,
|
| 650 |
+
return_dict=True,
|
| 651 |
+
labels=this_targets,
|
| 652 |
+
reduction="none",
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
loss = outputs.loss
|
| 656 |
+
|
| 657 |
+
loss = loss.reshape(bs, seg_len)
|
| 658 |
+
# output_class_ranks = torch.argsort(loss, dim=-1)
|
| 659 |
+
all_losses.append(loss)
|
| 660 |
+
|
| 661 |
+
all_losses = torch.cat(all_losses, dim=-1)
|
| 662 |
+
output_class_ranks = torch.argsort(all_losses, dim=-1)
|
| 663 |
+
|
| 664 |
+
return output_class_ranks
|
| 665 |
+
|
| 666 |
+
def _lemmatize(self, answers):
|
| 667 |
+
def apply(answer):
|
| 668 |
+
doc = self.lemmatizer(answer)
|
| 669 |
+
|
| 670 |
+
words = []
|
| 671 |
+
for token in doc:
|
| 672 |
+
if token.pos_ in ["NOUN", "VERB"]:
|
| 673 |
+
words.append(token.lemma_)
|
| 674 |
+
else:
|
| 675 |
+
words.append(token.text)
|
| 676 |
+
answer = " ".join(words)
|
| 677 |
+
|
| 678 |
+
return answer
|
| 679 |
+
|
| 680 |
+
return [apply(answer) for answer in answers]
|
| 681 |
+
|
| 682 |
+
@property
|
| 683 |
+
def lemmatizer(self):
|
| 684 |
+
if self._lemmatizer is None:
|
| 685 |
+
try:
|
| 686 |
+
import spacy
|
| 687 |
+
|
| 688 |
+
self._lemmatizer = spacy.load("en_core_web_sm")
|
| 689 |
+
except ImportError:
|
| 690 |
+
logging.error(
|
| 691 |
+
"""
|
| 692 |
+
Please install spacy and en_core_web_sm model to apply lemmatization.
|
| 693 |
+
python -m spacy download en_core_web_sm
|
| 694 |
+
OR
|
| 695 |
+
import spacy.cli
|
| 696 |
+
spacy.cli.download("en_core_web_sm")
|
| 697 |
+
"""
|
| 698 |
+
)
|
| 699 |
+
exit(1)
|
| 700 |
+
|
| 701 |
+
return self._lemmatizer
|
| 702 |
+
|
| 703 |
+
@classmethod
|
| 704 |
+
def from_config(cls, cfg):
|
| 705 |
+
vit_model = cfg.get("vit_model", "eva_clip_g")
|
| 706 |
+
img_size = cfg.get("image_size")
|
| 707 |
+
num_query_token = cfg.get("num_query_token")
|
| 708 |
+
llm_model = cfg.get("llm_model")
|
| 709 |
+
|
| 710 |
+
drop_path_rate = cfg.get("drop_path_rate", 0)
|
| 711 |
+
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
|
| 712 |
+
vit_precision = cfg.get("vit_precision", "fp16")
|
| 713 |
+
freeze_vit = cfg.get("freeze_vit", True)
|
| 714 |
+
|
| 715 |
+
prompt = cfg.get("prompt", "")
|
| 716 |
+
max_txt_len = cfg.get("max_txt_len", 128)
|
| 717 |
+
max_output_txt_len = cfg.get("max_output_txt_len", 256)
|
| 718 |
+
|
| 719 |
+
apply_lemmatizer = cfg.get("apply_lemmatizer", False)
|
| 720 |
+
|
| 721 |
+
qformer_text_input = cfg.get("qformer_text_input", True)
|
| 722 |
+
|
| 723 |
+
model = cls(
|
| 724 |
+
vit_model=vit_model,
|
| 725 |
+
img_size=img_size,
|
| 726 |
+
drop_path_rate=drop_path_rate,
|
| 727 |
+
use_grad_checkpoint=use_grad_checkpoint,
|
| 728 |
+
vit_precision=vit_precision,
|
| 729 |
+
freeze_vit=freeze_vit,
|
| 730 |
+
num_query_token=num_query_token,
|
| 731 |
+
llm_model=llm_model,
|
| 732 |
+
prompt=prompt,
|
| 733 |
+
max_txt_len=max_txt_len,
|
| 734 |
+
max_output_txt_len=max_output_txt_len,
|
| 735 |
+
apply_lemmatizer=apply_lemmatizer,
|
| 736 |
+
qformer_text_input=qformer_text_input,
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
# if qformer_text_input:
|
| 740 |
+
# # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal)
|
| 741 |
+
# model.load_from_pretrained(
|
| 742 |
+
# url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth"
|
| 743 |
+
# )
|
| 744 |
+
|
| 745 |
+
model.load_checkpoint_from_config(cfg)
|
| 746 |
+
|
| 747 |
+
return model
|
emo/all.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import shutil
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
out = []
|
| 10 |
+
# reasoning
|
| 11 |
+
folder_path_reasoning = './emo/reasoning/'
|
| 12 |
+
filelist_reasoning = os.listdir(folder_path_reasoning)
|
| 13 |
+
|
| 14 |
+
for class_name in filelist_reasoning:
|
| 15 |
+
path = os.path.join(folder_path_reasoning, class_name)
|
| 16 |
+
item = os.listdir(path)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
for name in item:
|
| 20 |
+
with open(folder_path_reasoning + class_name + '/' + name, 'r', encoding='utf-8') as file:
|
| 21 |
+
text = file.read()
|
| 22 |
+
pattern = r"(?i)Question\s*:(.*?)\s*Answer\s*:(.*?)(?=\s*(Question\s*:|Answer\s*:|$))"
|
| 23 |
+
|
| 24 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 25 |
+
reasoning = []
|
| 26 |
+
|
| 27 |
+
for match in matches:
|
| 28 |
+
question = match[0].strip()
|
| 29 |
+
answer = match[1].strip()
|
| 30 |
+
reasoning.append({"from": "human", "value": question})
|
| 31 |
+
reasoning.append({"from": "gpt", "value": answer})
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
for i in range(int(len(reasoning)/2)):
|
| 35 |
+
out.append({"id": name.split('_')[1][:5], "image": name.split('.')[0] + '.jpg', 'conversations': reasoning[2*i:2*i+2]})
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# conversation
|
| 39 |
+
|
| 40 |
+
folder_path = './emo/conversation/'
|
| 41 |
+
filelist = os.listdir(folder_path)
|
| 42 |
+
for class_name in filelist:
|
| 43 |
+
path = os.path.join(folder_path, class_name)
|
| 44 |
+
item = os.listdir(path)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
for name in item:
|
| 48 |
+
with open(folder_path + class_name + '/' + name, 'r', encoding='utf-8') as file:
|
| 49 |
+
text = file.read()
|
| 50 |
+
|
| 51 |
+
pattern = r"(?i)Question\s*\d*:(.*?)\s*Answer\s*\d*:(.*?)\s*(?=(Question:\d*|Complex Question:\d*|Complex question:\d*|$))"
|
| 52 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 53 |
+
conversations = []
|
| 54 |
+
|
| 55 |
+
for match in matches:
|
| 56 |
+
question = match[0].strip()
|
| 57 |
+
answer = match[1].strip()
|
| 58 |
+
conversations.append({"from": "human", "value": question})
|
| 59 |
+
conversations.append({"from": "gpt", "value": answer})
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
#conversations[0]['value'] = conversations[0]['value'] + '\n<image>'
|
| 63 |
+
conversations[0]['value'] = conversations[0]['value']
|
| 64 |
+
|
| 65 |
+
#out.append({"id": name.split('_')[1][:5], "image": name.split('.')[0] + '.jpg', 'conversations': conversations})
|
| 66 |
+
|
| 67 |
+
for i in range(int(len(conversations)/2)):
|
| 68 |
+
out.append({"id": name.split('_')[1][:5], "image": name.split('.')[0] + '.jpg', 'conversations': conversations[2*i:2*i+2]})
|
| 69 |
+
|
| 70 |
+
shutil.copy('./emo/image/' + class_name + '/' + name[:-3] + 'jpg', './emo/image/train_image')
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
##### classification
|
| 74 |
+
with open('./emo/train.json', 'r') as json_file:
|
| 75 |
+
json_data = json.load(json_file)
|
| 76 |
+
|
| 77 |
+
amusement_data = []
|
| 78 |
+
anger_data = []
|
| 79 |
+
awe_data = []
|
| 80 |
+
contentment_data = []
|
| 81 |
+
disgust_data = []
|
| 82 |
+
excitement_data = []
|
| 83 |
+
fear_data = []
|
| 84 |
+
sadness_data = []
|
| 85 |
+
|
| 86 |
+
for item in json_data:
|
| 87 |
+
category = item[0]
|
| 88 |
+
if category == 'amusement':
|
| 89 |
+
amusement_data.append(item[1].split('/')[2][:-4])
|
| 90 |
+
elif category == 'anger':
|
| 91 |
+
anger_data.append(item[1].split('/')[2][:-4])
|
| 92 |
+
elif category == 'awe':
|
| 93 |
+
awe_data.append(item[1].split('/')[2][:-4])
|
| 94 |
+
elif category == 'contentment':
|
| 95 |
+
contentment_data.append(item[1].split('/')[2][:-4])
|
| 96 |
+
elif category == 'disgust':
|
| 97 |
+
disgust_data.append(item[1].split('/')[2][:-4])
|
| 98 |
+
elif category == 'excitement':
|
| 99 |
+
excitement_data.append(item[1].split('/')[2][:-4])
|
| 100 |
+
elif category == 'fear':
|
| 101 |
+
fear_data.append(item[1].split('/')[2][:-4])
|
| 102 |
+
elif category == 'sadness':
|
| 103 |
+
sadness_data.append(item[1].split('/')[2][:-4])
|
| 104 |
+
|
| 105 |
+
all_data = [amusement_data, anger_data, awe_data, contentment_data, disgust_data, excitement_data, fear_data, sadness_data]
|
| 106 |
+
emo = ['amusement', 'anger', 'awe', 'contentment', 'disgust', 'excitement', 'fear', 'sadness']
|
| 107 |
+
|
| 108 |
+
for i in range(8):
|
| 109 |
+
for j in range(1000, 5600):
|
| 110 |
+
word = [
|
| 111 |
+
{
|
| 112 |
+
"from": "human",
|
| 113 |
+
"value": "Please select the emotion closest to the image from the following options:\
|
| 114 |
+
amusement, \
|
| 115 |
+
anger, \
|
| 116 |
+
awe, \
|
| 117 |
+
contentment, \
|
| 118 |
+
disgust, \
|
| 119 |
+
excitement, \
|
| 120 |
+
fear and sadness \
|
| 121 |
+
(Do not provide answers outside of the candidates options.) Please answer in the following format: Predict emotion:"
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"from": "gpt",
|
| 125 |
+
"value": 'Predict emotion: ' + emo[i]
|
| 126 |
+
}
|
| 127 |
+
]
|
| 128 |
+
temp = {'id': all_data[i][j][-5:], 'image': all_data[i][j] + '.jpg', 'conversations': word}
|
| 129 |
+
|
| 130 |
+
out.append(temp)
|
| 131 |
+
|
| 132 |
+
shutil.copy('./emo/image/' + emo[i] + '/' + all_data[i][j] + '.jpg', './emo/image/train_image')
|
| 133 |
+
#####
|
| 134 |
+
|
| 135 |
+
random.shuffle(out)
|
| 136 |
+
with open('./emo/train.json', 'w') as json_file:
|
| 137 |
+
json.dump(out, json_file, indent=2)
|
| 138 |
+
|
emo/cap-anno.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
class_name = 'sadness'
|
| 5 |
+
|
| 6 |
+
path = './emo/caption/' + class_name + '/'
|
| 7 |
+
filelist = os.listdir(path)
|
| 8 |
+
|
| 9 |
+
caption_path = './emo/caption/' + class_name + '/'
|
| 10 |
+
annotation_path = './emo/annotation/' + class_name + '/'
|
| 11 |
+
for name in filelist:
|
| 12 |
+
print(name)
|
| 13 |
+
with open(caption_path + name, 'r', encoding='utf-8') as file:
|
| 14 |
+
caption = file.read()
|
| 15 |
+
with open(annotation_path + name.split('txt')[0] + 'json', 'r') as json_file:
|
| 16 |
+
annotation = json.load(json_file)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
out = caption
|
| 20 |
+
out = out + '\n\n'
|
| 21 |
+
out = out + 'emotion: ' + str(annotation['emotion'])
|
| 22 |
+
if 'brightness' in annotation:
|
| 23 |
+
out = out + '\n' + 'brightness: ' + str(annotation['brightness'])
|
| 24 |
+
if 'colorfulness' in annotation:
|
| 25 |
+
out = out + '\n' + 'colorfulness: ' + str(annotation['colorfulness'])
|
| 26 |
+
|
| 27 |
+
if 'object' in annotation:
|
| 28 |
+
out = out + '\n' + 'object: ' + str(annotation['object'])
|
| 29 |
+
if 'facial_expression' in annotation:
|
| 30 |
+
out = out + '\n' + 'facial_expression: ' + str(annotation['facial_expression'])
|
| 31 |
+
if 'human_action' in annotation:
|
| 32 |
+
out = out + '\n' + 'human_action: ' + str(annotation['human_action'])
|
| 33 |
+
|
| 34 |
+
out_path = "./emo/cap-ano/" + class_name + "/" + name
|
| 35 |
+
f = open(out_path, 'w')
|
| 36 |
+
f.write(out)
|
| 37 |
+
f.close()
|
emo/caption.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import requests
|
| 4 |
+
from lavis.models import load_model_and_preprocess
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
|
| 8 |
+
model, vis_processors, _ = load_model_and_preprocess(
|
| 9 |
+
name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
path = './emo/image/sadness/'
|
| 13 |
+
filelist = os.listdir(path)
|
| 14 |
+
|
| 15 |
+
for name in filelist:
|
| 16 |
+
print('-----------')
|
| 17 |
+
print(name)
|
| 18 |
+
out_path = './emo/caption/sadness/' + name.split('.')[0] + '.txt'
|
| 19 |
+
f = open(out_path, 'w')
|
| 20 |
+
raw_image = Image.open(path + name)
|
| 21 |
+
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
|
| 22 |
+
|
| 23 |
+
caption = model.generate({"image": image})
|
| 24 |
+
print(caption[0])
|
| 25 |
+
f.write(caption[0])
|
| 26 |
+
f.close()
|
emo/desktop.ini
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[LocalizedFileNames]
|
| 2 |
+
EmoSet-118K.zip=@EmoSet-118K,0
|
emo/gpt4_conversation.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import json
|
| 3 |
+
import openai
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
#openai.api_key need to change to your own key
|
| 7 |
+
#Search "Need change!!!" in this script
|
| 8 |
+
#Change the number in range
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def generate_chat_completion(messages, model="gpt-4", temperature=1, max_tokens=None):
|
| 12 |
+
openai.api_key ="" #key
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
response = openai.ChatCompletion.create(
|
| 16 |
+
model="gpt-4",
|
| 17 |
+
max_tokens=None,
|
| 18 |
+
temperature=1,
|
| 19 |
+
messages = messages)
|
| 20 |
+
|
| 21 |
+
return response["choices"][0]["message"]["content"]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
#####
|
| 25 |
+
with open('./emo/train.json', 'r') as json_file:
|
| 26 |
+
json_data = json.load(json_file)
|
| 27 |
+
|
| 28 |
+
amusement_data = []
|
| 29 |
+
anger_data = []
|
| 30 |
+
awe_data = []
|
| 31 |
+
contentment_data = []
|
| 32 |
+
disgust_data = []
|
| 33 |
+
excitement_data = []
|
| 34 |
+
fear_data = []
|
| 35 |
+
sadness_data = []
|
| 36 |
+
|
| 37 |
+
for item in json_data:
|
| 38 |
+
category = item[0]
|
| 39 |
+
if category == 'amusement':
|
| 40 |
+
amusement_data.append(item[1].split('/')[2][:-4])
|
| 41 |
+
elif category == 'anger':
|
| 42 |
+
anger_data.append(item[1].split('/')[2][:-4])
|
| 43 |
+
elif category == 'awe':
|
| 44 |
+
awe_data.append(item[1].split('/')[2][:-4])
|
| 45 |
+
elif category == 'contentment':
|
| 46 |
+
contentment_data.append(item[1].split('/')[2][:-4])
|
| 47 |
+
elif category == 'disgust':
|
| 48 |
+
disgust_data.append(item[1].split('/')[2][:-4])
|
| 49 |
+
elif category == 'excitement':
|
| 50 |
+
excitement_data.append(item[1].split('/')[2][:-4])
|
| 51 |
+
elif category == 'fear':
|
| 52 |
+
fear_data.append(item[1].split('/')[2][:-4])
|
| 53 |
+
elif category == 'sadness':
|
| 54 |
+
sadness_data.append(item[1].split('/')[2][:-4])
|
| 55 |
+
|
| 56 |
+
#####
|
| 57 |
+
|
| 58 |
+
prompt_path = "./emo/prompt/conversation.txt"
|
| 59 |
+
with open(prompt_path, 'r', encoding='utf-8') as file:
|
| 60 |
+
content = file.read()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
#Need change!!!
|
| 64 |
+
class_name = 'sadness'
|
| 65 |
+
filelist = sadness_data
|
| 66 |
+
|
| 67 |
+
path = './emo/cap-ano/' + class_name + '/'
|
| 68 |
+
for i in range(0, 1000):
|
| 69 |
+
print(i)
|
| 70 |
+
name = filelist[i]
|
| 71 |
+
caption_path = "./emo/cap-ano/" + class_name + "/" + name + '.txt'
|
| 72 |
+
with open(caption_path, 'r', encoding='utf-8') as file:
|
| 73 |
+
caption = file.read()
|
| 74 |
+
|
| 75 |
+
messages = [
|
| 76 |
+
{"role": "system", "content": content},
|
| 77 |
+
{"role": "user", "content": caption}
|
| 78 |
+
]
|
| 79 |
+
#print(caption)
|
| 80 |
+
|
| 81 |
+
response_text = generate_chat_completion(messages)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
out_path = "./emo/conversation/" + class_name + "/" + name + '.txt'
|
| 85 |
+
f = open(out_path, 'w')
|
| 86 |
+
f.write(response_text)
|
| 87 |
+
f.close()
|
emo/gpt4_reasoning.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import json
|
| 3 |
+
import openai
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
#openai.api_key need to change to your own key
|
| 7 |
+
#Search "Need change!!!" in this script
|
| 8 |
+
#Change the number in range
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def generate_chat_completion(messages, model="gpt-4", temperature=1, max_tokens=None):
|
| 12 |
+
openai.api_key ="" #key;
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
response = openai.ChatCompletion.create(
|
| 16 |
+
model="gpt-4",
|
| 17 |
+
max_tokens=None,
|
| 18 |
+
temperature=1,
|
| 19 |
+
messages = messages)
|
| 20 |
+
|
| 21 |
+
return response["choices"][0]["message"]["content"]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
#####
|
| 25 |
+
with open('./emo/train.json', 'r') as json_file:
|
| 26 |
+
json_data = json.load(json_file)
|
| 27 |
+
|
| 28 |
+
amusement_data = []
|
| 29 |
+
anger_data = []
|
| 30 |
+
awe_data = []
|
| 31 |
+
contentment_data = []
|
| 32 |
+
disgust_data = []
|
| 33 |
+
excitement_data = []
|
| 34 |
+
fear_data = []
|
| 35 |
+
sadness_data = []
|
| 36 |
+
|
| 37 |
+
for item in json_data:
|
| 38 |
+
category = item[0]
|
| 39 |
+
if category == 'amusement':
|
| 40 |
+
amusement_data.append(item[1].split('/')[2][:-4])
|
| 41 |
+
elif category == 'anger':
|
| 42 |
+
anger_data.append(item[1].split('/')[2][:-4])
|
| 43 |
+
elif category == 'awe':
|
| 44 |
+
awe_data.append(item[1].split('/')[2][:-4])
|
| 45 |
+
elif category == 'contentment':
|
| 46 |
+
contentment_data.append(item[1].split('/')[2][:-4])
|
| 47 |
+
elif category == 'disgust':
|
| 48 |
+
disgust_data.append(item[1].split('/')[2][:-4])
|
| 49 |
+
elif category == 'excitement':
|
| 50 |
+
excitement_data.append(item[1].split('/')[2][:-4])
|
| 51 |
+
elif category == 'fear':
|
| 52 |
+
fear_data.append(item[1].split('/')[2][:-4])
|
| 53 |
+
elif category == 'sadness':
|
| 54 |
+
sadness_data.append(item[1].split('/')[2][:-4])
|
| 55 |
+
|
| 56 |
+
#####
|
| 57 |
+
|
| 58 |
+
prompt_path = "./emo/prompt/reasoning.txt"
|
| 59 |
+
with open(prompt_path, 'r', encoding='utf-8') as file:
|
| 60 |
+
content = file.read()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
#Need change!!!
|
| 64 |
+
class_name = 'fear'
|
| 65 |
+
filelist = fear_data
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
path = './emo/cap-ano/' + class_name + '/'
|
| 69 |
+
for i in range(0,100):
|
| 70 |
+
print(i)
|
| 71 |
+
name = filelist[i]
|
| 72 |
+
caption_path = "./emo/cap-ano/" + class_name + "/" + name + '.txt'
|
| 73 |
+
with open(caption_path, 'r', encoding='utf-8') as file:
|
| 74 |
+
caption = file.read()
|
| 75 |
+
|
| 76 |
+
messages = [
|
| 77 |
+
{"role": "system", "content": content},
|
| 78 |
+
{"role": "user", "content": caption}
|
| 79 |
+
]
|
| 80 |
+
# print(caption)
|
| 81 |
+
# assert 0
|
| 82 |
+
|
| 83 |
+
response_text = generate_chat_completion(messages)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
out_path = "./emo/reasoning/" + class_name + "/" + name + '.txt'
|
| 87 |
+
f = open(out_path, 'w')
|
| 88 |
+
f.write(response_text)
|
| 89 |
+
f.close()
|
emo/prompt/conversation.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are an AI visual assistant, and you are seeing a single image. What you see are provided with one caption and some emotion related attributes, describing the same image you are looking at. Answer all questions as you are seeing the image.
|
| 2 |
+
The range of brightness is from 0 (darkest) to 1 (brightest), and the range of colorfulness is from 0 (black-and-white) to 1 (the most colorful).
|
| 3 |
+
|
| 4 |
+
Design two questions for a conversation between you and a person asking about this photo. The answers should be in a tone that a visual AI assistant is seeing the image and answering the question.
|
| 5 |
+
Ask diverse questions and give corresponding answers.
|
| 6 |
+
|
| 7 |
+
Include questions asking about the visual content of the image, including the object types, object actions, relationship among objects, etc. Only include questions that have definite answers:
|
| 8 |
+
(1) one can see the content in the image that the question asks about and can answer confidently;
|
| 9 |
+
(2) one can determine confidently from the image that it is not in the image.
|
| 10 |
+
Do not ask any question that cannot be answered confidently.
|
| 11 |
+
Please answer with the format
|
| 12 |
+
Question:
|
| 13 |
+
Answer:
|
| 14 |
+
|
| 15 |
+
Also include one complex question that is relevant to the content in the image, for example, asking about background knowledge of the objects in the image, asking to discuss about events happening in the image, etc. Again, do not ask about uncertain details.
|
| 16 |
+
Provide detailed answers when answering complex questions. For example, give detailed examples or reasoning steps to make the content more convincing and well-organized. You can include multiple paragraphs if necessary.
|
emo/prompt/description.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are an AI visual assistant that can analyze a single image. What you see are provided with one caption and some emotion related attributes, describing the same image you are looking at.
|
| 2 |
+
|
| 3 |
+
Using the provided caption and attributes, describe the scene in a detailed manner.
|
| 4 |
+
|
| 5 |
+
When using the information from the caption and attributes, directly explain the scene, and do not mention that the information source is the caption or the attributes. Always answer as if you are directly looking at the image.
|
| 6 |
+
Please answer with the format
|
| 7 |
+
Question:
|
| 8 |
+
Answer:
|
emo/prompt/reasoning.txt
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are an AI visual assistant that can analyze a single image. You receive one caption and some emotion related attributes, describing the same image you are looking at.
|
| 2 |
+
|
| 3 |
+
The task is to use the provided caption and attributes, create a plausible question about the image, and provide the answer in detail.
|
| 4 |
+
|
| 5 |
+
Create one complex question beyond describing the scene.
|
| 6 |
+
To answer such question, one should require first understanding the visual content, then based on the background knowledge or reasoning, either explain why the things are happening that way, or provide guides and help to user's request. Make the question challenging by not including the visual content details in the question so that the user needs to reason about that first.
|
| 7 |
+
|
| 8 |
+
When using the information from the caption and attributes, directly explain the scene, and do not mention that the information source is the caption or the attributes. Always answer as if you are directly looking at the image.
|
| 9 |
+
|
| 10 |
+
Following are examples for your information:
|
| 11 |
+
|
| 12 |
+
caption & attributes:
|
| 13 |
+
|
| 14 |
+
a young girl wearing a cheerleading outfit at a restaurant
|
| 15 |
+
|
| 16 |
+
emotion: amusement
|
| 17 |
+
brightness: 0.5
|
| 18 |
+
colorfulness: 0.6
|
| 19 |
+
facial_expression: happy
|
| 20 |
+
|
| 21 |
+
Question: Considering the dining location and the girl's attire, what type of event might have led her to be in this situation, and why she could be feeling amused?
|
| 22 |
+
|
| 23 |
+
Answer: It is probable that the girl is coming from a cheerleading competition or performance since she's dressed in a cheerleading outfit. This could be a school game, cheer clubs event or community celebration where cheerleading was part of the festivities. Often such events foster a sense of camaraderie and joy - elements that could contribute to her amusement. Perhaps her team was successful, or maybe she simply enjoyed performing. Moreover, it is common for teams to celebrate together after events, which would explain their choice to go to a restaurant. Eating together becomes part of the shared experience of the day. Her happy facial expression implies that she is enjoying the current moment, possibly sharing exciting recounts of the day or joking around with her teammates or family. Despite the moderate brightness and colorfulness of the restaurant, the overall mood is upbeat and cheerful.
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
caption & attributes:
|
| 27 |
+
a figurine of an asian woman riding on top of a yellow machine
|
| 28 |
+
|
| 29 |
+
emotion: amusement
|
| 30 |
+
brightness: 0.8
|
| 31 |
+
colorfulness: 0.7
|
| 32 |
+
|
| 33 |
+
Question: Given the peculiar nature of the figurine with an Asian woman on a yellow machine, what cultural or pop culture references could this object possibly be alluding to which may serve as a source of amusement?
|
| 34 |
+
|
| 35 |
+
Answer: The figurine could be indicative of various cultural or pop culture elements. For instance, considering the bright and colorful nature of the figurine, it might be referencing the vibrant Asian pop culture. The yellow machine may be a nod towards the popularity of mecha (giant robots) in Asian cartoons and anime. The woman riding the machine might be a playful take on the trope of strong female characters seen in many of these works. As such, the amusement might arise from the recognition of these references and the whimsical portrayal of these themes in a form as innocent as a figurine. Sometimes, these figurines are designed in a playful and exaggerated manner to depict everyday or fictional scenarios with a touch of humor, which might also lead to the experienced amusement. The bright colors and the high brightness level add to the cheerfulness and comical nature of the scene.
|
| 36 |
+
|
| 37 |
+
caption & attributes:
|
| 38 |
+
a man standing in front of a display of powerpuff dolls
|
| 39 |
+
|
| 40 |
+
emotion: amusement
|
| 41 |
+
brightness: 0.6
|
| 42 |
+
colorfulness: 0.8
|
| 43 |
+
object: ['Toy']
|
| 44 |
+
facial_expression: happy
|
| 45 |
+
|
| 46 |
+
Question: Taking into account the scene with the man and the display of Powerpuff dolls, why might this scenario be amusing and evoke happiness in him, considering the age and gender stereotypes often associated with such collectibles?
|
| 47 |
+
|
| 48 |
+
Answer: This image could be amusing due to the contrast between the man and the Powerpuff dolls, which are typically associated with a younger, female demographic. The man's amused expression and the vibrant colors and brightness of the scene suggest a playful and light-hearted atmosphere. Perhaps the man used to watch the Powerpuff Girls show when he was younger and the dolls took him down the memory lane, sparking a sense of nostalgia, a factor that could certainly contribute to his happiness. Maybe he is a collector who appreciates these toys for their design and the cultural significance they hold. Alternatively, he could be a parent or an uncle shopping for a young relative and the sight of the familiar characters from his past brought a smile to his face. Regardless, the image challenges the conventional stereotype regarding who should enjoy toys and cartoon characters, showing that amusement can be found in unexpected places and situations.
|
| 49 |
+
|
emo/test.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
emo/train.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b3043ba41320d9309513d9661e64c1f89d2c40370302596718b774f5917e772f
|
| 3 |
+
size 10750390
|
emo/val.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Flask and Web Framework
|
| 2 |
+
Flask==2.3.3
|
| 3 |
+
gunicorn==21.2.0
|
| 4 |
+
|
| 5 |
+
# Core ML Libraries
|
| 6 |
+
torch>=1.10.0
|
| 7 |
+
torchvision>=0.11.0
|
| 8 |
+
transformers>=4.28.0
|
| 9 |
+
pillow>=10.0.0
|
| 10 |
+
|
| 11 |
+
# Image Processing
|
| 12 |
+
opencv-python-headless>=4.5.0
|
| 13 |
+
|
| 14 |
+
# LAVIS dependencies (for BLIP2)
|
| 15 |
+
salesforce-lavis
|
| 16 |
+
omegaconf>=2.3.0
|
| 17 |
+
|
| 18 |
+
# Data handling
|
| 19 |
+
numpy>=1.24.0
|
| 20 |
+
pandas>=2.0.0
|
| 21 |
+
|
| 22 |
+
# Other utilities
|
| 23 |
+
requests>=2.31.0
|
| 24 |
+
tqdm>=4.66.0
|
| 25 |
+
safetensors>=0.4.0
|
| 26 |
+
|
| 27 |
+
# For Hugging Face compatibility
|
| 28 |
+
huggingface-hub>=0.20.0
|
| 29 |
+
|
| 30 |
+
# Additional utilities that might be needed
|
| 31 |
+
einops>=0.7.0
|
| 32 |
+
timm>=0.4.12
|
| 33 |
+
sentencepiece>=0.1.99
|
| 34 |
+
|
| 35 |
+
# For better logging
|
| 36 |
+
loguru
|
| 37 |
+
|
| 38 |
+
# For environment variables
|
| 39 |
+
python-dotenv
|
requirements_emo.txt
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.1.0
|
| 2 |
+
altair==5.1.2
|
| 3 |
+
annotated-types==0.6.0
|
| 4 |
+
antlr4-python3-runtime==4.9.3
|
| 5 |
+
asttokens==2.4.1
|
| 6 |
+
attrs==23.1.0
|
| 7 |
+
backcall==0.2.0
|
| 8 |
+
backports.zoneinfo==0.2.1
|
| 9 |
+
bleach==6.1.0
|
| 10 |
+
blinker==1.7.0
|
| 11 |
+
blis==0.7.11
|
| 12 |
+
braceexpand==0.1.7
|
| 13 |
+
cachetools==5.3.3
|
| 14 |
+
catalogue==2.0.10
|
| 15 |
+
certifi==2023.7.22
|
| 16 |
+
cfgv==3.4.0
|
| 17 |
+
charset-normalizer==3.3.2
|
| 18 |
+
click==8.1.7
|
| 19 |
+
cloudpathlib==0.16.0
|
| 20 |
+
confection==0.1.3
|
| 21 |
+
contexttimer==0.3.3
|
| 22 |
+
contourpy==1.1.1
|
| 23 |
+
cycler==0.12.1
|
| 24 |
+
cymem==2.0.8
|
| 25 |
+
Cython==3.0.10
|
| 26 |
+
decorator==5.1.1
|
| 27 |
+
decord==0.6.0
|
| 28 |
+
diffusers==0.16.0
|
| 29 |
+
distlib==0.3.7
|
| 30 |
+
einops==0.7.0
|
| 31 |
+
executing==2.0.1
|
| 32 |
+
fairscale==0.4.4
|
| 33 |
+
filelock==3.13.1
|
| 34 |
+
fonttools==4.44.0
|
| 35 |
+
fsspec==2023.10.0
|
| 36 |
+
ftfy==6.1.1
|
| 37 |
+
gitdb==4.0.11
|
| 38 |
+
GitPython==3.1.40
|
| 39 |
+
google-auth==2.29.0
|
| 40 |
+
google-auth-oauthlib==1.0.0
|
| 41 |
+
grpcio==1.62.1
|
| 42 |
+
huggingface-hub==0.20.2
|
| 43 |
+
identify==2.5.31
|
| 44 |
+
idna==3.4
|
| 45 |
+
imageio==2.32.0
|
| 46 |
+
importlib-resources==6.1.1
|
| 47 |
+
importlib_metadata==7.1.0
|
| 48 |
+
iopath==0.1.10
|
| 49 |
+
ipython==8.12.3
|
| 50 |
+
jedi==0.19.1
|
| 51 |
+
Jinja2==3.1.2
|
| 52 |
+
jsonschema==4.19.2
|
| 53 |
+
jsonschema-specifications==2023.7.1
|
| 54 |
+
kaggle==1.5.16
|
| 55 |
+
kiwisolver==1.4.5
|
| 56 |
+
langcodes==3.3.0
|
| 57 |
+
lazy_loader==0.3
|
| 58 |
+
Markdown==3.6
|
| 59 |
+
markdown-it-py==3.0.0
|
| 60 |
+
MarkupSafe==2.1.5
|
| 61 |
+
matplotlib==3.7.3
|
| 62 |
+
matplotlib-inline==0.1.7
|
| 63 |
+
mdurl==0.1.2
|
| 64 |
+
mpmath==1.3.0
|
| 65 |
+
murmurhash==1.0.10
|
| 66 |
+
networkx==3.1
|
| 67 |
+
nodeenv==1.8.0
|
| 68 |
+
numpy==1.24.4
|
| 69 |
+
nvidia-cublas-cu12==12.1.3.1
|
| 70 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
| 71 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
| 72 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
| 73 |
+
nvidia-cudnn-cu12==8.9.2.26
|
| 74 |
+
nvidia-cufft-cu12==11.0.2.54
|
| 75 |
+
nvidia-curand-cu12==10.3.2.106
|
| 76 |
+
nvidia-cusolver-cu12==11.4.5.107
|
| 77 |
+
nvidia-cusparse-cu12==12.1.0.106
|
| 78 |
+
nvidia-nccl-cu12==2.18.1
|
| 79 |
+
nvidia-nvjitlink-cu12==12.3.52
|
| 80 |
+
nvidia-nvtx-cu12==12.1.105
|
| 81 |
+
omegaconf==2.3.0
|
| 82 |
+
opencv-python-headless==4.5.5.64
|
| 83 |
+
opendatasets==0.1.22
|
| 84 |
+
packaging==24.0
|
| 85 |
+
pandas==2.0.3
|
| 86 |
+
parso==0.8.4
|
| 87 |
+
pexpect==4.8.0
|
| 88 |
+
pickleshare==0.7.5
|
| 89 |
+
pillow==10.3.0
|
| 90 |
+
pkgutil_resolve_name==1.3.10
|
| 91 |
+
plotly==5.18.0
|
| 92 |
+
portalocker==2.8.2
|
| 93 |
+
pre-commit==3.5.0
|
| 94 |
+
preshed==3.0.9
|
| 95 |
+
prompt-toolkit==3.0.43
|
| 96 |
+
protobuf==5.26.1
|
| 97 |
+
ptyprocess==0.7.0
|
| 98 |
+
pure-eval==0.2.2
|
| 99 |
+
pyarrow==14.0.1
|
| 100 |
+
pycocoevalcap==1.2
|
| 101 |
+
pycocotools==2.0.7
|
| 102 |
+
pydantic==2.5.0
|
| 103 |
+
pydantic_core==2.14.1
|
| 104 |
+
pydeck==0.8.1b0
|
| 105 |
+
Pygments==2.17.2
|
| 106 |
+
pyparsing==3.1.1
|
| 107 |
+
python-dateutil==2.8.2
|
| 108 |
+
python-magic==0.4.27
|
| 109 |
+
python-slugify==8.0.1
|
| 110 |
+
pytz==2024.1
|
| 111 |
+
PyWavelets==1.4.1
|
| 112 |
+
PyYAML==6.0.1
|
| 113 |
+
referencing==0.30.2
|
| 114 |
+
regex==2023.10.3
|
| 115 |
+
requests==2.31.0
|
| 116 |
+
requests-oauthlib==2.0.0
|
| 117 |
+
rich==13.6.0
|
| 118 |
+
rpds-py==0.12.0
|
| 119 |
+
rsa==4.9
|
| 120 |
+
safetensors==0.4.0
|
| 121 |
+
scikit-image==0.21.0
|
| 122 |
+
scipy==1.10.1
|
| 123 |
+
sentencepiece==0.1.99
|
| 124 |
+
six==1.16.0
|
| 125 |
+
smart-open==6.4.0
|
| 126 |
+
smmap==5.0.1
|
| 127 |
+
spacy==3.7.2
|
| 128 |
+
spacy-legacy==3.0.12
|
| 129 |
+
spacy-loggers==1.0.5
|
| 130 |
+
srsly==2.4.8
|
| 131 |
+
stack-data==0.6.3
|
| 132 |
+
streamlit==1.28.2
|
| 133 |
+
sympy==1.12
|
| 134 |
+
tenacity==8.2.3
|
| 135 |
+
tensorboard==2.14.0
|
| 136 |
+
tensorboard-data-server==0.7.2
|
| 137 |
+
text-unidecode==1.3
|
| 138 |
+
thinc==8.2.1
|
| 139 |
+
tifffile==2023.7.10
|
| 140 |
+
timm==0.4.12
|
| 141 |
+
tokenizers==0.13.3
|
| 142 |
+
toml==0.10.2
|
| 143 |
+
toolz==0.12.0
|
| 144 |
+
torch==1.10.2+cu111
|
| 145 |
+
torch-tb-profiler==0.4.3
|
| 146 |
+
torchaudio==0.10.2+cu111
|
| 147 |
+
torchvision==0.11.3+cu111
|
| 148 |
+
tqdm==4.66.1
|
| 149 |
+
traitlets==5.14.3
|
| 150 |
+
transformers==4.28.0
|
| 151 |
+
triton==2.1.0
|
| 152 |
+
typer==0.9.0
|
| 153 |
+
typing_extensions==4.11.0
|
| 154 |
+
tzdata==2024.1
|
| 155 |
+
tzlocal==5.2
|
| 156 |
+
urllib3==1.26.18
|
| 157 |
+
validators==0.22.0
|
| 158 |
+
virtualenv==20.24.6
|
| 159 |
+
wasabi==1.1.2
|
| 160 |
+
watchdog==3.0.0
|
| 161 |
+
wcwidth==0.2.13
|
| 162 |
+
weasel==0.3.4
|
| 163 |
+
webdataset==0.2.75
|
| 164 |
+
webencodings==0.5.1
|
| 165 |
+
Werkzeug==3.0.1
|
| 166 |
+
zipp==3.17.0
|
requirements_lavis.txt
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
altair==5.4.1
|
| 2 |
+
annotated-types==0.7.0
|
| 3 |
+
antlr4-python3-runtime==4.9.3
|
| 4 |
+
asttokens==2.4.1
|
| 5 |
+
attrs==24.2.0
|
| 6 |
+
backcall==0.2.0
|
| 7 |
+
bleach==6.1.0
|
| 8 |
+
blinker==1.8.2
|
| 9 |
+
blis==0.7.11
|
| 10 |
+
braceexpand==0.1.7
|
| 11 |
+
cachetools==5.5.0
|
| 12 |
+
catalogue==2.0.10
|
| 13 |
+
certifi==2024.8.30
|
| 14 |
+
cfgv==3.4.0
|
| 15 |
+
charset-normalizer==3.3.2
|
| 16 |
+
click==8.1.7
|
| 17 |
+
cloudpathlib==0.19.0
|
| 18 |
+
cmake==3.30.3
|
| 19 |
+
confection==0.1.5
|
| 20 |
+
contexttimer==0.3.3
|
| 21 |
+
contourpy==1.1.1
|
| 22 |
+
cycler==0.12.1
|
| 23 |
+
cymem==2.0.8
|
| 24 |
+
decorator==5.1.1
|
| 25 |
+
decord==0.6.0
|
| 26 |
+
diffusers==0.16.0
|
| 27 |
+
distlib==0.3.8
|
| 28 |
+
einops==0.8.0
|
| 29 |
+
executing==2.1.0
|
| 30 |
+
fairscale==0.4.4
|
| 31 |
+
filelock==3.16.1
|
| 32 |
+
fonttools==4.53.1
|
| 33 |
+
fsspec==2024.9.0
|
| 34 |
+
ftfy==6.2.3
|
| 35 |
+
gitdb==4.0.11
|
| 36 |
+
GitPython==3.1.43
|
| 37 |
+
huggingface-hub==0.25.0
|
| 38 |
+
identify==2.6.1
|
| 39 |
+
idna==3.10
|
| 40 |
+
imageio==2.35.1
|
| 41 |
+
importlib_metadata==8.5.0
|
| 42 |
+
importlib_resources==6.4.5
|
| 43 |
+
iopath==0.1.10
|
| 44 |
+
ipython==8.12.3
|
| 45 |
+
jedi==0.19.1
|
| 46 |
+
Jinja2==3.1.4
|
| 47 |
+
jsonschema==4.23.0
|
| 48 |
+
jsonschema-specifications==2023.12.1
|
| 49 |
+
kaggle==1.6.17
|
| 50 |
+
kiwisolver==1.4.7
|
| 51 |
+
langcodes==3.4.0
|
| 52 |
+
language_data==1.2.0
|
| 53 |
+
lazy_loader==0.4
|
| 54 |
+
lit==18.1.8
|
| 55 |
+
marisa-trie==1.2.0
|
| 56 |
+
markdown-it-py==3.0.0
|
| 57 |
+
MarkupSafe==2.1.5
|
| 58 |
+
matplotlib==3.7.5
|
| 59 |
+
matplotlib-inline==0.1.7
|
| 60 |
+
mdurl==0.1.2
|
| 61 |
+
mpmath==1.3.0
|
| 62 |
+
murmurhash==1.0.10
|
| 63 |
+
narwhals==1.8.2
|
| 64 |
+
networkx==3.1
|
| 65 |
+
nodeenv==1.9.1
|
| 66 |
+
numpy==1.24.4
|
| 67 |
+
nvidia-cublas-cu11==11.10.3.66
|
| 68 |
+
nvidia-cuda-cupti-cu11==11.7.101
|
| 69 |
+
nvidia-cuda-nvrtc-cu11==11.7.99
|
| 70 |
+
nvidia-cuda-runtime-cu11==11.7.99
|
| 71 |
+
nvidia-cudnn-cu11==8.5.0.96
|
| 72 |
+
nvidia-cufft-cu11==10.9.0.58
|
| 73 |
+
nvidia-curand-cu11==10.2.10.91
|
| 74 |
+
nvidia-cusolver-cu11==11.4.0.1
|
| 75 |
+
nvidia-cusparse-cu11==11.7.4.91
|
| 76 |
+
nvidia-nccl-cu11==2.14.3
|
| 77 |
+
nvidia-nvtx-cu11==11.7.91
|
| 78 |
+
omegaconf==2.3.0
|
| 79 |
+
opencv-python-headless==4.5.5.64
|
| 80 |
+
opendatasets==0.1.22
|
| 81 |
+
packaging==24.1
|
| 82 |
+
pandas==2.0.3
|
| 83 |
+
parso==0.8.4
|
| 84 |
+
pexpect==4.9.0
|
| 85 |
+
pickleshare==0.7.5
|
| 86 |
+
pillow==10.4.0
|
| 87 |
+
pkgutil_resolve_name==1.3.10
|
| 88 |
+
platformdirs==4.3.6
|
| 89 |
+
plotly==5.24.1
|
| 90 |
+
portalocker==2.10.1
|
| 91 |
+
pre-commit==3.5.0
|
| 92 |
+
preshed==3.0.9
|
| 93 |
+
prompt_toolkit==3.0.47
|
| 94 |
+
protobuf==5.28.2
|
| 95 |
+
ptyprocess==0.7.0
|
| 96 |
+
pure_eval==0.2.3
|
| 97 |
+
pyarrow==17.0.0
|
| 98 |
+
pycocoevalcap==1.2
|
| 99 |
+
pycocotools==2.0.7
|
| 100 |
+
pydantic==2.9.2
|
| 101 |
+
pydantic_core==2.23.4
|
| 102 |
+
pydeck==0.9.1
|
| 103 |
+
Pygments==2.18.0
|
| 104 |
+
pyparsing==3.1.4
|
| 105 |
+
python-dateutil==2.9.0.post0
|
| 106 |
+
python-magic==0.4.27
|
| 107 |
+
python-slugify==8.0.4
|
| 108 |
+
pytz==2024.2
|
| 109 |
+
PyWavelets==1.4.1
|
| 110 |
+
PyYAML==6.0.2
|
| 111 |
+
referencing==0.35.1
|
| 112 |
+
regex==2024.9.11
|
| 113 |
+
requests==2.32.3
|
| 114 |
+
rich==13.8.1
|
| 115 |
+
rpds-py==0.20.0
|
| 116 |
+
safetensors==0.4.5
|
| 117 |
+
scikit-image==0.21.0
|
| 118 |
+
scipy==1.10.1
|
| 119 |
+
sentencepiece==0.2.0
|
| 120 |
+
shellingham==1.5.4
|
| 121 |
+
six==1.16.0
|
| 122 |
+
smart-open==7.0.4
|
| 123 |
+
smmap==5.0.1
|
| 124 |
+
spacy==3.7.6
|
| 125 |
+
spacy-legacy==3.0.12
|
| 126 |
+
spacy-loggers==1.0.5
|
| 127 |
+
srsly==2.4.8
|
| 128 |
+
stack-data==0.6.3
|
| 129 |
+
streamlit==1.38.0
|
| 130 |
+
sympy==1.13.3
|
| 131 |
+
tenacity==8.5.0
|
| 132 |
+
text-unidecode==1.3
|
| 133 |
+
thinc==8.2.5
|
| 134 |
+
tifffile==2023.7.10
|
| 135 |
+
timm==0.4.12
|
| 136 |
+
tokenizers==0.13.3
|
| 137 |
+
toml==0.10.2
|
| 138 |
+
torch==2.0.0
|
| 139 |
+
torchaudio==2.0.1
|
| 140 |
+
torchvision==0.15.1
|
| 141 |
+
tornado==6.4.1
|
| 142 |
+
tqdm==4.66.5
|
| 143 |
+
traitlets==5.14.3
|
| 144 |
+
transformers==4.31.0
|
| 145 |
+
triton==2.0.0
|
| 146 |
+
typer==0.12.5
|
| 147 |
+
typing_extensions==4.12.2
|
| 148 |
+
tzdata==2024.1
|
| 149 |
+
urllib3==2.2.3
|
| 150 |
+
virtualenv==20.26.5
|
| 151 |
+
wasabi==1.1.3
|
| 152 |
+
watchdog==4.0.2
|
| 153 |
+
wcwidth==0.2.13
|
| 154 |
+
weasel==0.4.1
|
| 155 |
+
webdataset==0.2.100
|
| 156 |
+
webencodings==0.5.1
|
| 157 |
+
wrapt==1.16.0
|
| 158 |
+
zipp==3.20.2
|
static/css/style.css
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Custom CSS for EmoVIT Application */
|
| 2 |
+
|
| 3 |
+
:root {
|
| 4 |
+
--primary-color: #4f46e5;
|
| 5 |
+
--primary-dark: #3730a3;
|
| 6 |
+
--secondary-color: #06b6d4;
|
| 7 |
+
--success-color: #10b981;
|
| 8 |
+
--warning-color: #f59e0b;
|
| 9 |
+
--danger-color: #ef4444;
|
| 10 |
+
--light-color: #f8fafc;
|
| 11 |
+
--dark-color: #1e293b;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
body {
|
| 15 |
+
font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
| 16 |
+
line-height: 1.6;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
/* Background Gradient */
|
| 20 |
+
.bg-gradient-primary {
|
| 21 |
+
background: linear-gradient(135deg, var(--primary-color) 0%, var(--secondary-color) 100%);
|
| 22 |
+
min-height: 100vh;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
/* Custom Card Styling */
|
| 26 |
+
.card {
|
| 27 |
+
border: none;
|
| 28 |
+
box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 10px 10px -5px rgba(0, 0, 0, 0.04);
|
| 29 |
+
transition: all 0.3s ease;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
.card:hover {
|
| 33 |
+
transform: translateY(-2px);
|
| 34 |
+
box-shadow: 0 25px 30px -5px rgba(0, 0, 0, 0.15), 0 15px 15px -5px rgba(0, 0, 0, 0.06);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
/* Button Styling */
|
| 38 |
+
.btn {
|
| 39 |
+
border-radius: 12px;
|
| 40 |
+
font-weight: 600;
|
| 41 |
+
text-transform: uppercase;
|
| 42 |
+
letter-spacing: 0.5px;
|
| 43 |
+
transition: all 0.3s ease;
|
| 44 |
+
border: 2px solid transparent;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
.btn-primary {
|
| 48 |
+
background: linear-gradient(135deg, var(--primary-color), var(--primary-dark));
|
| 49 |
+
border: none;
|
| 50 |
+
box-shadow: 0 4px 15px rgba(79, 70, 229, 0.3);
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.btn-primary:hover {
|
| 54 |
+
transform: translateY(-2px);
|
| 55 |
+
box-shadow: 0 8px 25px rgba(79, 70, 229, 0.4);
|
| 56 |
+
background: linear-gradient(135deg, var(--primary-dark), var(--primary-color));
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
.btn-outline-primary {
|
| 60 |
+
border-color: var(--primary-color);
|
| 61 |
+
color: var(--primary-color);
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
.btn-outline-primary:hover {
|
| 65 |
+
background: var(--primary-color);
|
| 66 |
+
border-color: var(--primary-color);
|
| 67 |
+
transform: translateY(-2px);
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
/* Form Styling */
|
| 71 |
+
.form-control {
|
| 72 |
+
border-radius: 10px;
|
| 73 |
+
border: 2px solid #e2e8f0;
|
| 74 |
+
padding: 12px 16px;
|
| 75 |
+
font-size: 16px;
|
| 76 |
+
transition: all 0.3s ease;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
.form-control:focus {
|
| 80 |
+
border-color: var(--primary-color);
|
| 81 |
+
box-shadow: 0 0 0 3px rgba(79, 70, 229, 0.1);
|
| 82 |
+
transform: translateY(-1px);
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
.form-control-lg {
|
| 86 |
+
padding: 16px 20px;
|
| 87 |
+
font-size: 18px;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
/* Alert Styling */
|
| 91 |
+
.alert {
|
| 92 |
+
border-radius: 12px;
|
| 93 |
+
border: none;
|
| 94 |
+
padding: 20px;
|
| 95 |
+
font-weight: 500;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
.alert-info {
|
| 99 |
+
background: linear-gradient(135deg, #e0f2fe, #b3e5fc);
|
| 100 |
+
color: #0277bd;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
.alert-success {
|
| 104 |
+
background: linear-gradient(135deg, #e8f5e8, #c8e6c9);
|
| 105 |
+
color: #2e7d32;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
.alert-danger {
|
| 109 |
+
background: linear-gradient(135deg, #ffebee, #ffcdd2);
|
| 110 |
+
color: #c62828;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
/* Loading Spinner */
|
| 114 |
+
.spinner-border {
|
| 115 |
+
width: 3rem;
|
| 116 |
+
height: 3rem;
|
| 117 |
+
border-width: 0.3em;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
/* Image Preview */
|
| 121 |
+
#previewImage {
|
| 122 |
+
max-width: 100%;
|
| 123 |
+
border-radius: 10px;
|
| 124 |
+
transition: transform 0.3s ease;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
#previewImage:hover {
|
| 128 |
+
transform: scale(1.02);
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
/* Upload Section */
|
| 132 |
+
.upload-section {
|
| 133 |
+
position: relative;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
.upload-section::before {
|
| 137 |
+
content: '';
|
| 138 |
+
position: absolute;
|
| 139 |
+
top: -20px;
|
| 140 |
+
left: 50%;
|
| 141 |
+
transform: translateX(-50%);
|
| 142 |
+
width: 60px;
|
| 143 |
+
height: 4px;
|
| 144 |
+
background: linear-gradient(90deg, var(--primary-color), var(--secondary-color));
|
| 145 |
+
border-radius: 2px;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
/* Results Section */
|
| 149 |
+
.emotion-result {
|
| 150 |
+
text-align: center;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
.emotion-result h4 {
|
| 154 |
+
margin-bottom: 20px;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
/* Icon Styling */
|
| 158 |
+
.fas, .far {
|
| 159 |
+
transition: transform 0.3s ease;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
.btn:hover .fas,
|
| 163 |
+
.btn:hover .far {
|
| 164 |
+
transform: scale(1.1);
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
/* Card Headers */
|
| 168 |
+
.card-header {
|
| 169 |
+
background: linear-gradient(135deg, #f8fafc, #e2e8f0) !important;
|
| 170 |
+
border-bottom: 2px solid #e2e8f0;
|
| 171 |
+
border-radius: 12px 12px 0 0 !important;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
/* Responsive Design */
|
| 175 |
+
@media (max-width: 768px) {
|
| 176 |
+
.container {
|
| 177 |
+
padding: 15px;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
.card-body {
|
| 181 |
+
padding: 30px 20px;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
.display-4 {
|
| 185 |
+
font-size: 2.5rem;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
.btn-lg {
|
| 189 |
+
padding: 12px 30px;
|
| 190 |
+
font-size: 16px;
|
| 191 |
+
}
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
/* Animations */
|
| 195 |
+
@keyframes fadeIn {
|
| 196 |
+
from {
|
| 197 |
+
opacity: 0;
|
| 198 |
+
transform: translateY(20px);
|
| 199 |
+
}
|
| 200 |
+
to {
|
| 201 |
+
opacity: 1;
|
| 202 |
+
transform: translateY(0);
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
.card {
|
| 207 |
+
animation: fadeIn 0.6s ease-out;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
/* Custom Utilities */
|
| 211 |
+
.rounded-4 {
|
| 212 |
+
border-radius: 1.5rem !important;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
.text-white-50 {
|
| 216 |
+
color: rgba(255, 255, 255, 0.75) !important;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
/* File Input Styling */
|
| 220 |
+
input[type="file"] {
|
| 221 |
+
cursor: pointer;
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
input[type="file"]::-webkit-file-upload-button {
|
| 225 |
+
background: linear-gradient(135deg, var(--primary-color), var(--primary-dark));
|
| 226 |
+
color: white;
|
| 227 |
+
border: none;
|
| 228 |
+
padding: 8px 16px;
|
| 229 |
+
border-radius: 8px;
|
| 230 |
+
margin-right: 10px;
|
| 231 |
+
cursor: pointer;
|
| 232 |
+
font-weight: 600;
|
| 233 |
+
transition: all 0.3s ease;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
input[type="file"]::-webkit-file-upload-button:hover {
|
| 237 |
+
transform: translateY(-1px);
|
| 238 |
+
box-shadow: 0 4px 12px rgba(79, 70, 229, 0.3);
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
/* Progress Bar (for future use) */
|
| 242 |
+
.progress {
|
| 243 |
+
height: 8px;
|
| 244 |
+
border-radius: 4px;
|
| 245 |
+
background-color: #e2e8f0;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
.progress-bar {
|
| 249 |
+
background: linear-gradient(90deg, var(--primary-color), var(--secondary-color));
|
| 250 |
+
border-radius: 4px;
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
/* Tooltip Styling */
|
| 254 |
+
.tooltip {
|
| 255 |
+
font-size: 12px;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
.tooltip-inner {
|
| 259 |
+
background-color: var(--dark-color);
|
| 260 |
+
border-radius: 6px;
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
/* Focus States for Accessibility */
|
| 264 |
+
.btn:focus,
|
| 265 |
+
.form-control:focus {
|
| 266 |
+
outline: none;
|
| 267 |
+
box-shadow: 0 0 0 3px rgba(79, 70, 229, 0.25);
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
/* Error Styling */
|
| 271 |
+
.is-invalid {
|
| 272 |
+
border-color: var(--danger-color);
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
.invalid-feedback {
|
| 276 |
+
color: var(--danger-color);
|
| 277 |
+
font-weight: 500;
|
| 278 |
+
}
|
templates/index.html
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>EmoVIT - Emotion Detection</title>
|
| 7 |
+
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
| 8 |
+
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
|
| 9 |
+
<link href="{{ url_for('static', filename='css/style.css') }}" rel="stylesheet">
|
| 10 |
+
</head>
|
| 11 |
+
<body>
|
| 12 |
+
<div class="container-fluid bg-gradient-primary min-vh-100">
|
| 13 |
+
<div class="container py-5">
|
| 14 |
+
<div class="row justify-content-center">
|
| 15 |
+
<div class="col-lg-8">
|
| 16 |
+
<!-- Header -->
|
| 17 |
+
<div class="text-center mb-5">
|
| 18 |
+
<h1 class="display-4 text-white mb-3">
|
| 19 |
+
<i class="fas fa-smile text-warning me-3"></i>
|
| 20 |
+
EmoVIT
|
| 21 |
+
</h1>
|
| 22 |
+
<p class="lead text-white-50">
|
| 23 |
+
AI-Powered Emotion Detection using BLIP2-Vicuna
|
| 24 |
+
</p>
|
| 25 |
+
</div>
|
| 26 |
+
|
| 27 |
+
<!-- Main Card -->
|
| 28 |
+
<div class="card shadow-lg border-0 rounded-4">
|
| 29 |
+
<div class="card-body p-5">
|
| 30 |
+
<!-- Upload Section -->
|
| 31 |
+
<div class="upload-section mb-4">
|
| 32 |
+
<h3 class="text-center mb-4">
|
| 33 |
+
<i class="fas fa-upload text-primary me-2"></i>
|
| 34 |
+
Upload Image for Emotion Analysis
|
| 35 |
+
</h3>
|
| 36 |
+
|
| 37 |
+
<form id="uploadForm" enctype="multipart/form-data">
|
| 38 |
+
<!-- Custom Prompt -->
|
| 39 |
+
<div class="mb-4">
|
| 40 |
+
<label for="promptInput" class="form-label fw-bold">
|
| 41 |
+
<i class="fas fa-comment-dots me-2"></i>
|
| 42 |
+
Custom Prompt (Optional)
|
| 43 |
+
</label>
|
| 44 |
+
<input type="text"
|
| 45 |
+
class="form-control form-control-lg"
|
| 46 |
+
id="promptInput"
|
| 47 |
+
name="prompt"
|
| 48 |
+
placeholder="What emotion is shown in this image?"
|
| 49 |
+
value="What emotion is shown in this image?">
|
| 50 |
+
</div>
|
| 51 |
+
|
| 52 |
+
<!-- File Upload -->
|
| 53 |
+
<div class="mb-4">
|
| 54 |
+
<label for="imageInput" class="form-label fw-bold">
|
| 55 |
+
<i class="fas fa-image me-2"></i>
|
| 56 |
+
Select Image
|
| 57 |
+
</label>
|
| 58 |
+
<input type="file"
|
| 59 |
+
class="form-control form-control-lg"
|
| 60 |
+
id="imageInput"
|
| 61 |
+
name="image"
|
| 62 |
+
accept="image/*"
|
| 63 |
+
required>
|
| 64 |
+
</div>
|
| 65 |
+
|
| 66 |
+
<!-- Submit Button -->
|
| 67 |
+
<div class="text-center">
|
| 68 |
+
<button type="submit"
|
| 69 |
+
class="btn btn-primary btn-lg px-5 py-3"
|
| 70 |
+
id="analyzeBtn">
|
| 71 |
+
<i class="fas fa-brain me-2"></i>
|
| 72 |
+
Analyze Emotion
|
| 73 |
+
</button>
|
| 74 |
+
</div>
|
| 75 |
+
</form>
|
| 76 |
+
</div>
|
| 77 |
+
|
| 78 |
+
<!-- Loading Spinner -->
|
| 79 |
+
<div id="loadingSpinner" class="text-center d-none">
|
| 80 |
+
<div class="spinner-border text-primary" role="status">
|
| 81 |
+
<span class="visually-hidden">Loading...</span>
|
| 82 |
+
</div>
|
| 83 |
+
<p class="mt-3 text-muted">Analyzing emotion...</p>
|
| 84 |
+
</div>
|
| 85 |
+
|
| 86 |
+
<!-- Results Section -->
|
| 87 |
+
<div id="resultsSection" class="d-none">
|
| 88 |
+
<hr class="my-5">
|
| 89 |
+
<h3 class="text-center mb-4">
|
| 90 |
+
<i class="fas fa-chart-line text-success me-2"></i>
|
| 91 |
+
Analysis Results
|
| 92 |
+
</h3>
|
| 93 |
+
|
| 94 |
+
<div class="row">
|
| 95 |
+
<!-- Image Preview -->
|
| 96 |
+
<div class="col-md-6 mb-4">
|
| 97 |
+
<div class="card h-100">
|
| 98 |
+
<div class="card-header bg-light">
|
| 99 |
+
<h5 class="mb-0">
|
| 100 |
+
<i class="fas fa-image me-2"></i>
|
| 101 |
+
Uploaded Image
|
| 102 |
+
</h5>
|
| 103 |
+
</div>
|
| 104 |
+
<div class="card-body text-center">
|
| 105 |
+
<img id="previewImage"
|
| 106 |
+
src=""
|
| 107 |
+
alt="Uploaded image"
|
| 108 |
+
class="img-fluid rounded shadow-sm"
|
| 109 |
+
style="max-height: 300px;">
|
| 110 |
+
</div>
|
| 111 |
+
</div>
|
| 112 |
+
</div>
|
| 113 |
+
|
| 114 |
+
<!-- Results -->
|
| 115 |
+
<div class="col-md-6 mb-4">
|
| 116 |
+
<div class="card h-100">
|
| 117 |
+
<div class="card-header bg-light">
|
| 118 |
+
<h5 class="mb-0">
|
| 119 |
+
<i class="fas fa-brain me-2"></i>
|
| 120 |
+
Detected Emotion
|
| 121 |
+
</h5>
|
| 122 |
+
</div>
|
| 123 |
+
<div class="card-body">
|
| 124 |
+
<div class="alert alert-info" role="alert">
|
| 125 |
+
<strong>Prompt:</strong>
|
| 126 |
+
<p id="usedPrompt" class="mb-0 mt-2"></p>
|
| 127 |
+
</div>
|
| 128 |
+
|
| 129 |
+
<div class="emotion-result">
|
| 130 |
+
<h4 class="text-primary mb-3">
|
| 131 |
+
<i class="fas fa-smile-beam me-2"></i>
|
| 132 |
+
Result:
|
| 133 |
+
</h4>
|
| 134 |
+
<div class="alert alert-success" role="alert">
|
| 135 |
+
<p id="emotionResult" class="mb-0 fs-5 fw-bold"></p>
|
| 136 |
+
</div>
|
| 137 |
+
</div>
|
| 138 |
+
</div>
|
| 139 |
+
</div>
|
| 140 |
+
</div>
|
| 141 |
+
</div>
|
| 142 |
+
|
| 143 |
+
<!-- Try Again Button -->
|
| 144 |
+
<div class="text-center mt-4">
|
| 145 |
+
<button type="button"
|
| 146 |
+
class="btn btn-outline-primary btn-lg"
|
| 147 |
+
id="tryAgainBtn">
|
| 148 |
+
<i class="fas fa-redo me-2"></i>
|
| 149 |
+
Try Another Image
|
| 150 |
+
</button>
|
| 151 |
+
</div>
|
| 152 |
+
</div>
|
| 153 |
+
|
| 154 |
+
<!-- Error Section -->
|
| 155 |
+
<div id="errorSection" class="d-none">
|
| 156 |
+
<div class="alert alert-danger" role="alert">
|
| 157 |
+
<h4 class="alert-heading">
|
| 158 |
+
<i class="fas fa-exclamation-triangle me-2"></i>
|
| 159 |
+
Error
|
| 160 |
+
</h4>
|
| 161 |
+
<p id="errorMessage" class="mb-0"></p>
|
| 162 |
+
</div>
|
| 163 |
+
</div>
|
| 164 |
+
</div>
|
| 165 |
+
</div>
|
| 166 |
+
|
| 167 |
+
<!-- Footer -->
|
| 168 |
+
<div class="text-center mt-5">
|
| 169 |
+
<p class="text-white-50">
|
| 170 |
+
<i class="fas fa-robot me-2"></i>
|
| 171 |
+
Powered by BLIP2-Vicuna AI Model
|
| 172 |
+
</p>
|
| 173 |
+
</div>
|
| 174 |
+
</div>
|
| 175 |
+
</div>
|
| 176 |
+
</div>
|
| 177 |
+
</div>
|
| 178 |
+
|
| 179 |
+
<!-- Scripts -->
|
| 180 |
+
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"></script>
|
| 181 |
+
<script>
|
| 182 |
+
document.addEventListener('DOMContentLoaded', function() {
|
| 183 |
+
const uploadForm = document.getElementById('uploadForm');
|
| 184 |
+
const loadingSpinner = document.getElementById('loadingSpinner');
|
| 185 |
+
const resultsSection = document.getElementById('resultsSection');
|
| 186 |
+
const errorSection = document.getElementById('errorSection');
|
| 187 |
+
const tryAgainBtn = document.getElementById('tryAgainBtn');
|
| 188 |
+
const analyzeBtn = document.getElementById('analyzeBtn');
|
| 189 |
+
|
| 190 |
+
uploadForm.addEventListener('submit', async function(e) {
|
| 191 |
+
e.preventDefault();
|
| 192 |
+
|
| 193 |
+
// Hide previous results
|
| 194 |
+
resultsSection.classList.add('d-none');
|
| 195 |
+
errorSection.classList.add('d-none');
|
| 196 |
+
|
| 197 |
+
// Show loading
|
| 198 |
+
loadingSpinner.classList.remove('d-none');
|
| 199 |
+
analyzeBtn.disabled = true;
|
| 200 |
+
|
| 201 |
+
try {
|
| 202 |
+
const formData = new FormData(uploadForm);
|
| 203 |
+
|
| 204 |
+
const response = await fetch('/predict', {
|
| 205 |
+
method: 'POST',
|
| 206 |
+
body: formData
|
| 207 |
+
});
|
| 208 |
+
|
| 209 |
+
const result = await response.json();
|
| 210 |
+
|
| 211 |
+
if (result.success) {
|
| 212 |
+
// Display results
|
| 213 |
+
document.getElementById('previewImage').src = 'data:image/png;base64,' + result.image;
|
| 214 |
+
document.getElementById('emotionResult').textContent = result.emotion;
|
| 215 |
+
document.getElementById('usedPrompt').textContent = result.prompt;
|
| 216 |
+
|
| 217 |
+
resultsSection.classList.remove('d-none');
|
| 218 |
+
} else {
|
| 219 |
+
throw new Error(result.error || 'Unknown error occurred');
|
| 220 |
+
}
|
| 221 |
+
} catch (error) {
|
| 222 |
+
console.error('Error:', error);
|
| 223 |
+
document.getElementById('errorMessage').textContent = error.message;
|
| 224 |
+
errorSection.classList.remove('d-none');
|
| 225 |
+
} finally {
|
| 226 |
+
// Hide loading
|
| 227 |
+
loadingSpinner.classList.add('d-none');
|
| 228 |
+
analyzeBtn.disabled = false;
|
| 229 |
+
}
|
| 230 |
+
});
|
| 231 |
+
|
| 232 |
+
tryAgainBtn.addEventListener('click', function() {
|
| 233 |
+
resultsSection.classList.add('d-none');
|
| 234 |
+
errorSection.classList.add('d-none');
|
| 235 |
+
uploadForm.reset();
|
| 236 |
+
document.getElementById('promptInput').value = 'What emotion is shown in this image?';
|
| 237 |
+
});
|
| 238 |
+
|
| 239 |
+
// Preview image on selection
|
| 240 |
+
document.getElementById('imageInput').addEventListener('change', function(e) {
|
| 241 |
+
const file = e.target.files[0];
|
| 242 |
+
if (file) {
|
| 243 |
+
const reader = new FileReader();
|
| 244 |
+
reader.onload = function(e) {
|
| 245 |
+
// Could add image preview here if needed
|
| 246 |
+
};
|
| 247 |
+
reader.readAsDataURL(file);
|
| 248 |
+
}
|
| 249 |
+
});
|
| 250 |
+
});
|
| 251 |
+
</script>
|
| 252 |
+
</body>
|
| 253 |
+
</html>
|