Faham commited on
Commit
59d4479
·
1 Parent(s): d919881

CREATE: Dockerfile for deployment

Browse files
.dockerignore ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Git
2
+ .git
3
+ .gitignore
4
+ .gitattributes
5
+
6
+ # Python
7
+ __pycache__
8
+ *.pyc
9
+ *.pyo
10
+ *.pyd
11
+ .Python
12
+ env
13
+ pip-log.txt
14
+ pip-delete-this-directory.txt
15
+ .tox
16
+ .coverage
17
+ .coverage.*
18
+ .cache
19
+ nosetests.xml
20
+ coverage.xml
21
+ *.cover
22
+ *.log
23
+ .git
24
+ .mypy_cache
25
+ .pytest_cache
26
+ .hypothesis
27
+
28
+ # Jupyter Notebook
29
+ .ipynb_checkpoints
30
+
31
+ # Environment variables
32
+ .env
33
+ .venv
34
+ env/
35
+ venv/
36
+ ENV/
37
+ env.bak/
38
+ venv.bak/
39
+
40
+ # IDE
41
+ .vscode/
42
+ .idea/
43
+ *.swp
44
+ *.swo
45
+ *~
46
+
47
+ # OS
48
+ .DS_Store
49
+ .DS_Store?
50
+ ._*
51
+ .Spotlight-V100
52
+ .Trashes
53
+ ehthumbs.db
54
+ Thumbs.db
55
+
56
+ # Docker
57
+ Dockerfile
58
+ .dockerignore
59
+ docker-compose.yml
60
+ docker-compose.yaml
61
+
62
+ # Documentation
63
+ README*.md
64
+ DEPLOYMENT_GUIDE.md
65
+ *.md
66
+
67
+ # Deployment scripts
68
+ deploy_to_spaces.py
69
+ app_spaces.py
70
+ requirements_spaces.txt
71
+ README_spaces.md
72
+
73
+ # Large files that shouldn't be in container
74
+ *.pth
75
+ *.bin
76
+ *.safetensors
77
+ *.ckpt
78
+ *.h5
79
+ *.hdf5
80
+ *.pkl
81
+ *.pickle
82
+ *.joblib
83
+ *.model
84
+ *.weights
85
+ *.pt
86
+ *.onnx
87
+ *.tflite
88
+ *.pb
89
+ *.savedmodel
90
+ *.mar
91
+ *.mlmodel
92
+ *.mlpackage
93
+ *.mlflow
94
+ *.bundle
95
+
96
+ # Archives
97
+ *.zip
98
+ *.tar.gz
99
+ *.rar
100
+ *.7z
101
+ *.gz
102
+ *.bz2
103
+ *.xz
104
+ *.lzma
105
+ *.zst
106
+ *.lz4
107
+ *.br
108
+
109
+ # Temporary files
110
+ *.tmp
111
+ *.temp
112
+ temp/
113
+ tmp/
.gitattributes ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
2
+ *.bin filter=lfs diff=lfs merge=lfs -text
3
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
4
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
5
+ *.h5 filter=lfs diff=lfs merge=lfs -text
6
+ *.hdf5 filter=lfs diff=lfs merge=lfs -text
7
+ *.pkl filter=lfs diff=lfs merge=lfs -text
8
+ *.pickle filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.weights filter=lfs diff=lfs merge=lfs -text
12
+ *.pt filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.tflite filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.savedmodel filter=lfs diff=lfs merge=lfs -text
17
+ *.mar filter=lfs diff=lfs merge=lfs -text
18
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
19
+ *.mlpackage filter=lfs diff=lfs merge=lfs -text
20
+ *.mlflow filter=lfs diff=lfs merge=lfs -text
21
+ *.bundle filter=lfs diff=lfs merge=lfs -text
22
+ *.zip filter=lfs diff=lfs merge=lfs -text
23
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.7z filter=lfs diff=lfs merge=lfs -text
26
+ *.gz filter=lfs diff=lfs merge=lfs -text
27
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.lzma filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *.lz4 filter=lfs diff=lfs merge=lfs -text
32
+ *.br filter=lfs diff=lfs merge=lfs -text
33
+ *.z filter=lfs diff=lfs merge=lfs -text
34
+ *.Z filter=lfs diff=lfs merge=lfs -text
35
+ *.ar filter=lfs diff=lfs merge=lfs -text
36
+ *.deb filter=lfs diff=lfs merge=lfs -text
37
+ *.rpm filter=lfs diff=lfs merge=lfs -text
38
+ *.iso filter=lfs diff=lfs merge=lfs -text
39
+ *.img filter=lfs diff=lfs merge=lfs -text
40
+ *.vmdk filter=lfs diff=lfs merge=lfs -text
41
+ *.vdi filter=lfs diff=lfs merge=lfs -text
42
+ *.vhd filter=lfs diff=lfs merge=lfs -text
43
+ *.vhdx filter=lfs diff=lfs merge=lfs -text
44
+ *.qcow2 filter=lfs diff=lfs merge=lfs -text
45
+ *.raw filter=lfs diff=lfs merge=lfs -text
46
+ *.dmg filter=lfs diff=lfs merge=lfs -text
47
+ *.pkg filter=lfs diff=lfs merge=lfs -text
48
+ *.exe filter=lfs diff=lfs merge=lfs -text
49
+ *.msi filter=lfs diff=lfs merge=lfs -text
50
+ *.app filter=lfs diff=lfs merge=lfs -text
51
+ *.dll filter=lfs diff=lfs merge=lfs -text
52
+ *.so filter=lfs diff=lfs merge=lfs -text
53
+ *.dylib filter=lfs diff=lfs merge=lfs -text
54
+ *.a filter=lfs diff=lfs merge=lfs -text
55
+ *.lib filter=lfs diff=lfs merge=lfs -text
56
+ *.o filter=lfs diff=lfs merge=lfs -text
57
+ *.obj filter=lfs diff=lfs merge=lfs -text
58
+ *.ko filter=lfs diff=lfs merge=lfs -text
59
+ *.elf filter=lfs diff=lfs merge=lfs -text
60
+ *.bin filter=lfs diff=lfs merge=lfs -text
61
+ *.hex filter=lfs diff=lfs merge=lfs -text
62
+ *.s19 filter=lfs diff=lfs merge=lfs -text
63
+ *.ihex filter=lfs diff=lfs merge=lfs -text
64
+ *.mot filter=lfs diff=lfs merge=lfs -text
65
+ *.srec filter=lfs diff=lfs merge=lfs -text
66
+ *.uboot filter=lfs diff=lfs merge=lfs -text
67
+ *.img filter=lfs diff=lfs merge=lfs -text
68
+ *.iso filter=lfs diff=lfs merge=lfs -text
69
+ *.vmdk filter=lfs diff=lfs merge=lfs -text
70
+ *.vdi filter=lfs diff=lfs merge=lfs -text
71
+ *.vhd filter=lfs diff=lfs merge=lfs -text
72
+ *.vhdx filter=lfs diff=lfs merge=lfs -text
73
+ *.qcow2 filter=lfs diff=lfs merge=lfs -text
74
+ *.raw filter=lfs diff=lfs merge=lfs -text
75
+ *.dmg filter=lfs diff=lfs merge=lfs -text
76
+ *.pkg filter=lfs diff=lfs merge=lfs -text
77
+ *.exe filter=lfs diff=lfs merge=lfs -text
78
+ *.msi filter=lfs diff=lfs merge=lfs -text
79
+ *.app filter=lfs diff=lfs merge=lfs -text
80
+ *.dll filter=lfs diff=lfs merge=lfs -text
81
+ *.so filter=lfs diff=lfs merge=lfs -text
82
+ *.dylib filter=lfs diff=lfs merge=lfs -text
83
+ *.a filter=lfs diff=lfs merge=lfs -text
84
+ *.lib filter=lfs diff=lfs merge=lfs -text
85
+ *.o filter=lfs diff=lfs merge=lfs -text
86
+ *.obj filter=lfs diff=lfs merge=lfs -text
87
+ *.ko filter=lfs diff=lfs merge=lfs -text
88
+ *.elf filter=lfs diff=lfs merge=lfs -text
89
+ *.bin filter=lfs diff=lfs merge=lfs -text
90
+ *.hex filter=lfs diff=lfs merge=lfs -text
91
+ *.s19 filter=lfs diff=lfs merge=lfs -text
92
+ *.ihex filter=lfs diff=lfs merge=lfs -text
93
+ *.mot filter=lfs diff=lfs merge=lfs -text
94
+ *.srec filter=lfs diff=lfs merge=lfs -text
95
+ *.uboot filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use Python 3.9 slim image
2
+ FROM python:3.9-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies for OpenCV and audio
8
+ RUN apt-get update && apt-get install -y \
9
+ libgl1-mesa-glx \
10
+ libglib2.0-0 \
11
+ libsm6 \
12
+ libxext6 \
13
+ libxrender-dev \
14
+ libavcodec-dev \
15
+ libavformat-dev \
16
+ libswscale-dev \
17
+ libv4l-dev \
18
+ libxvidcore-dev \
19
+ libx264-dev \
20
+ libjpeg-dev \
21
+ libpng-dev \
22
+ libtiff-dev \
23
+ && rm -rf /var/lib/apt/lists/*
24
+
25
+ # Copy requirements and install Python dependencies
26
+ COPY requirements.txt .
27
+ RUN pip install --no-cache-dir -r requirements.txt
28
+
29
+ # Copy the app
30
+ COPY app.py .
31
+ COPY simple_model_manager.py .
32
+
33
+ # Expose port
34
+ EXPOSE 7860
35
+
36
+ # Run Streamlit
37
+ CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
debug_drive.py DELETED
@@ -1,185 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Debug Google Drive download issues
4
- """
5
-
6
- import os
7
- import requests
8
- import re
9
- from pathlib import Path
10
-
11
-
12
- # Load .env file manually
13
- def load_env():
14
- env_file = Path(".env")
15
- if env_file.exists():
16
- with open(env_file, "r") as f:
17
- for line in f:
18
- line = line.strip()
19
- if line and not line.startswith("#") and "=" in line:
20
- key, value = line.split("=", 1)
21
- os.environ[key.strip()] = value.strip().strip('"')
22
-
23
-
24
- def test_drive_bypass(file_id):
25
- """Test different bypass methods"""
26
- print(f"Testing file ID: {file_id}")
27
- print("=" * 50)
28
-
29
- # Method 1: Direct bypass
30
- print("\n1. Testing direct bypass...")
31
- try:
32
- url = f"https://drive.usercontent.google.com/download?id={file_id}&export=download&confirm=t"
33
- response = requests.get(url, stream=True)
34
- print(f"Status: {response.status_code}")
35
- print(f"Content-Type: {response.headers.get('content-type', 'Unknown')}")
36
-
37
- first_chunk = next(response.iter_content(chunk_size=1024), b"")
38
- if first_chunk.startswith(b"<!DOCTYPE") or first_chunk.startswith(b"<html"):
39
- print("❌ Still getting HTML")
40
- html_content = first_chunk.decode("utf-8", errors="ignore")
41
- print(f"HTML preview: {html_content[:200]}...")
42
- else:
43
- print("✅ Got file content!")
44
- print(f"First bytes: {first_chunk[:50]}")
45
- return True
46
- except Exception as e:
47
- print(f"❌ Error: {e}")
48
-
49
- # Method 2: Session-based approach
50
- print("\n2. Testing session-based approach...")
51
- try:
52
- session = requests.Session()
53
- session.headers.update(
54
- {
55
- "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
56
- "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
57
- "Accept-Language": "en-US,en;q=0.5",
58
- }
59
- )
60
-
61
- # First get the virus scan page
62
- virus_url = f"https://drive.google.com/uc?export=download&id={file_id}"
63
- response = session.get(virus_url)
64
- print(f"Virus page status: {response.status_code}")
65
-
66
- # Extract confirm and UUID
67
- html_content = response.text
68
- confirm_match = re.search(r'name="confirm" value="([^"]+)"', html_content)
69
- uuid_match = re.search(r'name="uuid" value="([^"]+)"', html_content)
70
-
71
- if confirm_match and uuid_match:
72
- confirm_value = confirm_match.group(1)
73
- uuid_value = uuid_match.group(1)
74
- print(f"Found confirm: {confirm_value}")
75
- print(f"Found UUID: {uuid_value}")
76
-
77
- # Submit form
78
- form_data = {
79
- "id": file_id,
80
- "export": "download",
81
- "confirm": confirm_value,
82
- "uuid": uuid_value,
83
- }
84
- form_url = "https://drive.usercontent.google.com/download"
85
- response = session.post(form_url, data=form_data, stream=True)
86
-
87
- print(f"Form submission status: {response.status_code}")
88
- first_chunk = next(response.iter_content(chunk_size=1024), b"")
89
-
90
- if first_chunk.startswith(b"<!DOCTYPE") or first_chunk.startswith(b"<html"):
91
- print("❌ Form submission still returned HTML")
92
- else:
93
- print("✅ Form submission successful!")
94
- return True
95
- else:
96
- print("❌ Could not extract confirm/UUID")
97
- print(f"HTML preview: {html_content[:300]}...")
98
-
99
- except Exception as e:
100
- print(f"❌ Error: {e}")
101
-
102
- # Method 3: Extract download URL from file page
103
- print("\n3. Testing file page extraction...")
104
- try:
105
- session = requests.Session()
106
- session.headers.update(
107
- {
108
- "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
109
- "Referer": "https://drive.google.com/",
110
- }
111
- )
112
-
113
- file_url = f"https://drive.google.com/file/d/{file_id}/view"
114
- response = session.get(file_url)
115
- print(f"File page status: {response.status_code}")
116
-
117
- if response.status_code == 200:
118
- # Look for download URL in the page
119
- download_match = re.search(r'"downloadUrl":"([^"]+)"', response.text)
120
- if download_match:
121
- download_url = (
122
- download_match.group(1)
123
- .replace("\\u003d", "=")
124
- .replace("\\u0026", "&")
125
- )
126
- print(f"Found download URL: {download_url}")
127
-
128
- # Try downloading from this URL
129
- response = session.get(download_url, stream=True)
130
- first_chunk = next(response.iter_content(chunk_size=1024), b"")
131
-
132
- if first_chunk.startswith(b"<!DOCTYPE") or first_chunk.startswith(
133
- b"<html"
134
- ):
135
- print("❌ Download URL still returned HTML")
136
- else:
137
- print("✅ Download URL successful!")
138
- return True
139
- else:
140
- print("❌ Could not find download URL in page")
141
- else:
142
- print(f"❌ Could not access file page")
143
-
144
- except Exception as e:
145
- print(f"❌ Error: {e}")
146
-
147
- print("\n❌ All methods failed")
148
- return False
149
-
150
-
151
- def main():
152
- print("Google Drive Bypass Debug Tool")
153
- print("=" * 50)
154
-
155
- # Load environment variables
156
- load_env()
157
-
158
- # Get file ID from environment or user input
159
- vision_url = os.getenv("VISION_MODEL_DRIVE_LINK", "")
160
- audio_url = os.getenv("AUDIO_MODEL_DRIVE_LINK", "")
161
-
162
- if not vision_url and not audio_url:
163
- print("❌ No environment variables found!")
164
- print("Please set VISION_MODEL_DRIVE_LINK or AUDIO_MODEL_DRIVE_LINK")
165
- return
166
-
167
- if vision_url:
168
- print(f"\nTesting Vision Model URL: {vision_url}")
169
- if "/file/d/" in vision_url:
170
- file_id = vision_url.split("/file/d/")[1].split("/")[0]
171
- test_drive_bypass(file_id)
172
- else:
173
- print("❌ Invalid vision model URL format")
174
-
175
- if audio_url:
176
- print(f"\nTesting Audio Model URL: {audio_url}")
177
- if "/file/d/" in audio_url:
178
- file_id = audio_url.split("/file/d/")[1].split("/")[0]
179
- test_drive_bypass(file_id)
180
- else:
181
- print("❌ Invalid audio model URL format")
182
-
183
-
184
- if __name__ == "__main__":
185
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
run_app.py DELETED
@@ -1,65 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Startup script for the Sentiment Analysis Testing Ground Streamlit application.
4
- This script provides an easy way to launch the application with proper configuration.
5
- """
6
-
7
- import subprocess
8
- import sys
9
- import os
10
-
11
-
12
- def main():
13
- """Main function to start the Streamlit application."""
14
-
15
- print("🧠 Starting Sentiment Analysis Testing Ground...")
16
- print("=" * 50)
17
-
18
- # Check if app.py exists
19
- if not os.path.exists("app.py"):
20
- print("❌ Error: app.py not found in current directory!")
21
- print("Please make sure you're in the correct directory.")
22
- sys.exit(1)
23
-
24
- # Check if requirements are installed
25
- try:
26
- import streamlit
27
- import pandas
28
- import PIL
29
-
30
- print("✅ Dependencies check passed")
31
- except ImportError as e:
32
- print(f"❌ Missing dependency: {e}")
33
- print("Please install requirements: pip install -r requirements.txt")
34
- sys.exit(1)
35
-
36
- print("🚀 Launching Streamlit application...")
37
- print("📱 The app will open in your default browser")
38
- print("🔗 If it doesn't open automatically, go to: http://localhost:8501")
39
- print("⏹️ Press Ctrl+C to stop the application")
40
- print("=" * 50)
41
-
42
- try:
43
- # Start Streamlit with the app
44
- subprocess.run(
45
- [
46
- sys.executable,
47
- "-m",
48
- "streamlit",
49
- "run",
50
- "app.py",
51
- "--server.headless",
52
- "false",
53
- "--server.port",
54
- "8501",
55
- ]
56
- )
57
- except KeyboardInterrupt:
58
- print("\n👋 Application stopped by user")
59
- except Exception as e:
60
- print(f"❌ Error starting application: {e}")
61
- sys.exit(1)
62
-
63
-
64
- if __name__ == "__main__":
65
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
simple_model_manager.py CHANGED
@@ -1,8 +1,3 @@
1
- #!/usr/bin/env python3
2
- """
3
- Simple Model Manager - Downloads models from Google Drive using gdown
4
- """
5
-
6
  import os
7
  import gdown
8
  from pathlib import Path
@@ -11,22 +6,9 @@ from typing import Tuple, Any
11
  import torch
12
  import torch.nn as nn
13
  from torchvision import models
 
14
 
15
- # Try to load .env file if it exists
16
- try:
17
- from dotenv import load_dotenv
18
-
19
- load_dotenv()
20
- except ImportError:
21
- # If python-dotenv is not installed, try to load .env manually
22
- env_file = Path(".env")
23
- if env_file.exists():
24
- with open(env_file, "r") as f:
25
- for line in f:
26
- line = line.strip()
27
- if line and not line.startswith("#") and "=" in line:
28
- key, value = line.split("=", 1)
29
- os.environ[key.strip()] = value.strip()
30
 
31
  # Configure logging
32
  logging.basicConfig(level=logging.INFO)
 
 
 
 
 
 
1
  import os
2
  import gdown
3
  from pathlib import Path
 
6
  import torch
7
  import torch.nn as nn
8
  from torchvision import models
9
+ from dotenv import load_dotenv
10
 
11
+ load_dotenv()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Configure logging
14
  logging.basicConfig(level=logging.INFO)
test_audio_model.py DELETED
@@ -1,173 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Test script for the Wav2Vec2 audio sentiment analysis model
4
- """
5
-
6
- import os
7
- import torch
8
- import numpy as np
9
- import librosa
10
- from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
11
- import tempfile
12
-
13
-
14
- def test_audio_model():
15
- """Test the audio model loading and inference"""
16
-
17
- print("🔊 Testing Wav2Vec2 Audio Sentiment Model")
18
- print("=" * 50)
19
-
20
- # Check if model file exists
21
- model_path = "models/wav2vec2_model.pth"
22
- if not os.path.exists(model_path):
23
- print(f"❌ Audio model file not found at: {model_path}")
24
- return False
25
-
26
- print(f"✅ Found model file: {model_path}")
27
-
28
- try:
29
- # Set device
30
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
- print(f"🖥️ Using device: {device}")
32
-
33
- # Load the model checkpoint to check architecture
34
- checkpoint = torch.load(model_path, map_location=device)
35
- print(f"📊 Checkpoint keys: {list(checkpoint.keys())}")
36
-
37
- # Check for classifier weights
38
- if "classifier.weight" in checkpoint:
39
- num_classes = checkpoint["classifier.weight"].shape[0]
40
- print(f"📊 Model has {num_classes} output classes")
41
- else:
42
- print("⚠️ Could not determine number of classes from checkpoint")
43
- num_classes = 3 # Default assumption
44
-
45
- # Initialize model
46
- print("🔄 Initializing Wav2Vec2 model...")
47
- model_checkpoint = "facebook/wav2vec2-base"
48
- model = AutoModelForAudioClassification.from_pretrained(
49
- model_checkpoint, num_labels=num_classes
50
- )
51
-
52
- # Load trained weights
53
- print("🔄 Loading trained weights...")
54
- model.load_state_dict(checkpoint)
55
- model.to(device)
56
- model.eval()
57
-
58
- print("✅ Model loaded successfully!")
59
-
60
- # Test with dummy audio
61
- print("🧪 Testing inference with dummy audio...")
62
-
63
- # Create dummy audio (1 second of random noise at 16kHz)
64
- dummy_audio = np.random.randn(16000).astype(np.float32)
65
-
66
- # Load feature extractor
67
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
68
-
69
- # Preprocess audio
70
- inputs = feature_extractor(
71
- dummy_audio,
72
- sampling_rate=16000,
73
- max_length=80000, # 5 seconds * 16000 Hz
74
- truncation=True,
75
- padding="max_length",
76
- return_tensors="pt",
77
- )
78
-
79
- # Move to device
80
- input_values = inputs.input_values.to(device)
81
-
82
- # Run inference
83
- with torch.no_grad():
84
- outputs = model(input_values)
85
- probabilities = torch.softmax(outputs.logits, dim=1)
86
- confidence, predicted = torch.max(probabilities, 1)
87
-
88
- print(f"🔍 Model output shape: {outputs.logits.shape}")
89
- print(f"🎯 Predicted class: {predicted.item()}")
90
- print(f"📊 Confidence: {confidence.item():.3f}")
91
- print(f"📈 All probabilities: {probabilities.squeeze().cpu().numpy()}")
92
-
93
- # Sentiment mapping
94
- sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
95
- predicted_sentiment = sentiment_map.get(
96
- predicted.item(), f"Class_{predicted.item()}"
97
- )
98
- print(f"😊 Predicted sentiment: {predicted_sentiment}")
99
-
100
- print("✅ Audio model test completed successfully!")
101
- return True
102
-
103
- except Exception as e:
104
- print(f"❌ Error testing audio model: {str(e)}")
105
- import traceback
106
-
107
- traceback.print_exc()
108
- return False
109
-
110
-
111
- def check_audio_model_file():
112
- """Check the audio model file details"""
113
-
114
- print("\n🔍 Audio Model File Analysis")
115
- print("=" * 30)
116
-
117
- model_path = "models/wav2vec2_model.pth"
118
- if not os.path.exists(model_path):
119
- print(f"❌ Model file not found: {model_path}")
120
- return
121
-
122
- # File size
123
- file_size = os.path.getsize(model_path) / (1024 * 1024) # MB
124
- print(f"📁 File size: {file_size:.1f} MB")
125
-
126
- try:
127
- # Load checkpoint
128
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
129
- checkpoint = torch.load(model_path, map_location=device)
130
-
131
- print(f"🔑 Checkpoint keys ({len(checkpoint)} total):")
132
- for key, value in checkpoint.items():
133
- if isinstance(value, torch.Tensor):
134
- print(f" - {key}: {value.shape} ({value.dtype})")
135
- else:
136
- print(f" - {key}: {type(value)}")
137
-
138
- # Check classifier
139
- if "classifier.weight" in checkpoint:
140
- num_classes = checkpoint["classifier.weight"].shape[0]
141
- print(f"\n🎯 Classifier output classes: {num_classes}")
142
- print(
143
- f"📊 Classifier weight shape: {checkpoint['classifier.weight'].shape}"
144
- )
145
- if "classifier.bias" in checkpoint:
146
- print(
147
- f"📊 Classifier bias shape: {checkpoint['classifier.bias'].shape}"
148
- )
149
-
150
- # Check wav2vec2 base model
151
- if "wav2vec2.feature_extractor.conv_layers.0.conv.weight" in checkpoint:
152
- print(f"🔊 Wav2Vec2 base model: Present")
153
-
154
- except Exception as e:
155
- print(f"❌ Error analyzing checkpoint: {str(e)}")
156
-
157
-
158
- if __name__ == "__main__":
159
- print("🚀 Starting Wav2Vec2 Audio Model Tests")
160
- print("=" * 60)
161
-
162
- # Check model file
163
- check_audio_model_file()
164
-
165
- print("\n" + "=" * 60)
166
-
167
- # Test model loading and inference
168
- success = test_audio_model()
169
-
170
- if success:
171
- print("\n🎉 All audio model tests passed!")
172
- else:
173
- print("\n💥 Audio model tests failed!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_download.py DELETED
@@ -1,49 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Test the updated Google Drive download function
4
- """
5
-
6
- from simple_model_manager import SimpleModelManager
7
-
8
-
9
- def test_download():
10
- """Test the download function"""
11
- print("Testing Google Drive Download Function")
12
- print("=" * 50)
13
-
14
- # Initialize manager
15
- manager = SimpleModelManager()
16
-
17
- # Check model status
18
- status = manager.get_model_status()
19
- print("Model Status:")
20
- for model_type, info in status.items():
21
- print(f" {model_type}: {'✅' if info['configured'] else '❌'} {info['url']}")
22
- if info["cached"]:
23
- print(f" 📁 Cached: {info['filename']}")
24
-
25
- # Test vision model download
26
- if status["vision"]["configured"]:
27
- print(f"\nTesting vision model download...")
28
- try:
29
- vision_model, device, num_classes = manager.load_vision_model()
30
- print(f"✅ Vision model loaded: {num_classes} classes")
31
- except Exception as e:
32
- print(f"❌ Vision model failed: {e}")
33
- else:
34
- print("❌ Vision model not configured")
35
-
36
- # Test audio model download
37
- if status["audio"]["configured"]:
38
- print(f"\nTesting audio model download...")
39
- try:
40
- audio_model, device = manager.load_audio_model()
41
- print(f"✅ Audio model loaded")
42
- except Exception as e:
43
- print(f"❌ Audio model failed: {e}")
44
- else:
45
- print("❌ Audio model not configured")
46
-
47
-
48
- if __name__ == "__main__":
49
- test_download()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_drive_links.py DELETED
@@ -1,96 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Test Google Drive links directly to see what's being returned
4
- """
5
-
6
- import requests
7
- import os
8
- from pathlib import Path
9
-
10
-
11
- def test_drive_link(url, filename):
12
- """Test a Google Drive link directly"""
13
- print(f"\nTesting: {filename}")
14
- print(f"URL: {url}")
15
-
16
- try:
17
- # Convert Google Drive share link to direct download link
18
- if "drive.google.com" in url:
19
- if "/file/d/" in url:
20
- file_id = url.split("/file/d/")[1].split("/")[0]
21
- elif "id=" in url:
22
- file_id = url.split("id=")[1].split("&")[0]
23
- else:
24
- print("❌ Could not extract file ID")
25
- return
26
-
27
- direct_url = f"https://drive.google.com/uc?export=download&id={file_id}"
28
- print(f"Direct URL: {direct_url}")
29
- else:
30
- direct_url = url
31
-
32
- # Test the download
33
- print("Downloading...")
34
- response = requests.get(direct_url, stream=True)
35
-
36
- print(f"Status Code: {response.status_code}")
37
- print(f"Content-Type: {response.headers.get('content-type', 'Unknown')}")
38
- print(f"Content-Length: {response.headers.get('content-length', 'Unknown')}")
39
-
40
- if response.status_code == 200:
41
- # Read first 200 bytes to check content
42
- content = response.raw.read(200)
43
- print(f"First 200 bytes: {content[:100]}...")
44
-
45
- # Check if it's HTML
46
- if content.startswith(b"<!DOCTYPE") or content.startswith(b"<html"):
47
- print("❌ ERROR: This is an HTML page, not a model file!")
48
- print(" Your Google Drive link is not working properly")
49
- print(" Check file permissions and sharing settings")
50
- else:
51
- print("✅ Looks like a valid file (not HTML)")
52
-
53
- # Save a small sample to check
54
- sample_path = f"sample_{filename}"
55
- with open(sample_path, "wb") as f:
56
- f.write(content)
57
- print(f"Saved sample to: {sample_path}")
58
-
59
- else:
60
- print(f"❌ Download failed with status: {response.status_code}")
61
-
62
- except Exception as e:
63
- print(f"❌ Error: {e}")
64
-
65
-
66
- def main():
67
- print("Google Drive Link Tester")
68
- print("=" * 50)
69
-
70
- # Check environment variables
71
- vision_url = os.getenv("VISION_MODEL_DRIVE_LINK")
72
- audio_url = os.getenv("AUDIO_MODEL_DRIVE_LINK")
73
-
74
- if not vision_url and not audio_url:
75
- print("❌ No environment variables found!")
76
- print("Please run setup_env.py first or set:")
77
- print(" VISION_MODEL_DRIVE_LINK")
78
- print(" AUDIO_MODEL_DRIVE_LINK")
79
- return
80
-
81
- if vision_url:
82
- test_drive_link(vision_url, "resnet50_model.pth")
83
-
84
- if audio_url:
85
- test_drive_link(audio_url, "wav2vec2_model.pth")
86
-
87
- print("\n" + "=" * 50)
88
- print("If you see HTML content, your Google Drive links need fixing!")
89
- print("Make sure:")
90
- print(" 1. Files are set to 'Anyone with the link can view'")
91
- print(" 2. You're using direct file links, not folder links")
92
- print(" 3. Files are not too large for direct download")
93
-
94
-
95
- if __name__ == "__main__":
96
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_vision_model.py DELETED
@@ -1,136 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Test script for the vision sentiment analysis model.
4
- This script verifies that the ResNet-50 model can be loaded and run inference.
5
- """
6
-
7
- import os
8
- import sys
9
- import torch
10
- import torch.nn as nn
11
- from torchvision import transforms, models
12
- from PIL import Image
13
- import numpy as np
14
-
15
-
16
- def get_sentiment_mapping(num_classes):
17
- """Get the sentiment mapping based on number of classes"""
18
- if num_classes == 3:
19
- return {0: "Negative", 1: "Neutral", 2: "Positive"}
20
- elif num_classes == 4:
21
- # Common 4-class emotion mapping
22
- return {0: "Angry", 1: "Sad", 2: "Happy", 3: "Neutral"}
23
- elif num_classes == 7:
24
- # FER2013 7-class emotion mapping
25
- return {0: "Angry", 1: "Disgust", 2: "Fear", 3: "Happy", 4: "Sad", 5: "Surprise", 6: "Neutral"}
26
- else:
27
- # Generic mapping for unknown number of classes
28
- return {i: f"Class_{i}" for i in range(num_classes)}
29
-
30
-
31
- def test_vision_model():
32
- """Test the vision sentiment analysis model"""
33
-
34
- print("🧠 Testing Vision Sentiment Analysis Model")
35
- print("=" * 50)
36
-
37
- # Check if model file exists
38
- model_path = "models/resnet50_model.pth"
39
- if not os.path.exists(model_path):
40
- print(f"❌ Model file not found: {model_path}")
41
- print("Please ensure the model file exists in the models/ directory")
42
- return False
43
-
44
- print(f"✅ Model file found: {model_path}")
45
-
46
- try:
47
- # Load the model weights first to check the architecture
48
- print("📥 Loading model checkpoint...")
49
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
- checkpoint = torch.load(model_path, map_location=device)
51
-
52
- # Check the number of classes from the checkpoint
53
- if 'fc.weight' in checkpoint:
54
- num_classes = checkpoint['fc.weight'].shape[0]
55
- print(f"📊 Model checkpoint has {num_classes} output classes")
56
- else:
57
- # Fallback: try to infer from the last layer
58
- num_classes = 3 # Default assumption
59
- print("⚠️ Could not determine number of classes from checkpoint, assuming 3")
60
-
61
- # Initialize ResNet-50 model with the correct number of classes
62
- print("🔧 Initializing ResNet-50 model...")
63
- model = models.resnet50(weights=None) # Don't load ImageNet weights
64
- num_ftrs = model.fc.in_features
65
- model.fc = nn.Linear(num_ftrs, num_classes) # Use actual number of classes
66
-
67
- print(f"📥 Loading trained weights for {num_classes} classes...")
68
- model.load_state_dict(checkpoint)
69
- model.to(device)
70
- model.eval()
71
-
72
- print(f"✅ Model loaded successfully with {num_classes} classes!")
73
- print(f"🖥️ Using device: {device}")
74
-
75
- # Test with a dummy image
76
- print("🧪 Testing inference with dummy image...")
77
-
78
- # Create a dummy image (224x224 RGB)
79
- dummy_image = Image.fromarray(
80
- np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
81
- )
82
-
83
- # Apply transforms
84
- transform = transforms.Compose(
85
- [
86
- transforms.Resize(224),
87
- transforms.CenterCrop(224),
88
- transforms.ToTensor(),
89
- transforms.Normalize(
90
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
91
- ),
92
- ]
93
- )
94
-
95
- image_tensor = transform(dummy_image).unsqueeze(0).to(device)
96
-
97
- # Run inference
98
- with torch.no_grad():
99
- outputs = model(image_tensor)
100
- print(f"🔍 Model output shape: {outputs.shape}")
101
-
102
- probabilities = torch.nn.functional.softmax(outputs, dim=1)
103
- confidence, predicted = torch.max(probabilities, 1)
104
-
105
- # Get sentiment mapping based on number of classes
106
- sentiment_map = get_sentiment_mapping(num_classes)
107
- sentiment = sentiment_map[predicted.item()]
108
- confidence_score = confidence.item()
109
-
110
- print(f"🎯 Test prediction: {sentiment} (confidence: {confidence_score:.3f})")
111
- print(f"📋 Available classes: {list(sentiment_map.values())}")
112
- print("✅ Inference test passed!")
113
-
114
- return True
115
-
116
- except Exception as e:
117
- print(f"❌ Error testing model: {str(e)}")
118
- import traceback
119
- traceback.print_exc()
120
- return False
121
-
122
-
123
- def main():
124
- """Main function"""
125
- success = test_vision_model()
126
-
127
- if success:
128
- print("\n🎉 All tests passed! The vision model is ready to use.")
129
- print("You can now run the Streamlit app with: streamlit run app.py")
130
- else:
131
- print("\n💥 Tests failed. Please check the error messages above.")
132
- sys.exit(1)
133
-
134
-
135
- if __name__ == "__main__":
136
- main()