Add Dockerized Flask application files for deployment
Browse files- Dockerfile +66 -0
- flask_Character.py +31 -0
- requirements.txt +7 -0
- utils/__pycache__/config.cpython-311.pyc +0 -0
- utils/__pycache__/config.cpython-312.pyc +0 -0
- utils/__pycache__/generate_face_shapes.cpython-311.pyc +0 -0
- utils/__pycache__/generate_face_shapes.cpython-312.pyc +0 -0
- utils/audio/extraction/__pycache__/extract_features.cpython-311.pyc +0 -0
- utils/audio/extraction/__pycache__/extract_features.cpython-312.pyc +0 -0
- utils/audio/extraction/extract_features.py +245 -0
- utils/audio/processing/__pycache__/audio_processing.cpython-311.pyc +0 -0
- utils/audio/processing/__pycache__/audio_processing.cpython-312.pyc +0 -0
- utils/audio/processing/audio_processing.py +166 -0
- utils/config.py +12 -0
- utils/generate_face_shapes.py +22 -0
- utils/model/__pycache__/model.cpython-311.pyc +0 -0
- utils/model/__pycache__/model.cpython-312.pyc +0 -0
- utils/model/model.pth +3 -0
- utils/model/model.py +256 -0
Dockerfile
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Base image: slim Python 3.9. Pinning a specific tag keeps builds reproducible.
FROM python:3.9-slim

# PYTHONUNBUFFERED streams stdout/stderr without buffering so logs appear live.
# The PIP_* variables disable the pip cache and version check and raise the
# default network timeout (large wheels such as torch need it).
ENV PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=off \
    PIP_DISABLE_PIP_VERSION_CHECK=on \
    PIP_DEFAULT_TIMEOUT=100

# All subsequent COPY/RUN/CMD instructions execute relative to /app.
WORKDIR /app

# Copy the requirements file first to leverage Docker's layer caching:
# if only application code changes, the dependency layer is reused.
COPY requirements.txt .

# Install the Python dependencies listed in requirements.txt
# (Flask, numpy, torch, gunicorn, librosa, scipy).
RUN pip install -r requirements.txt

# Copy the rest of the project into the image: flask_Character.py, the
# utils/ package, and the model checkpoint (utils/model/model.pth).
COPY . /app

# Document the listening port; Hugging Face Spaces' Docker SDK expects 7860.
# EXPOSE does not publish the port by itself.
EXPOSE 7860

# Default container command: run the Flask app under gunicorn with 4 worker
# processes, bound to all interfaces on port 7860.
# 'flask_Character:app' is [module_name]:[Flask instance variable], i.e. the
# `app = flask.Flask(__name__)` object in flask_Character.py.
CMD ["gunicorn", "--workers", "4", "--bind", "0.0.0.0:7860", "flask_Character:app"]

# NOTE: any `if __name__ == '__main__': app.run(...)` block in
# flask_Character.py is ignored here — gunicorn imports the module and starts
# the application itself, so the development server is never used.
|
flask_Character.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# This software is licensed under a **dual-license model**
|
| 3 |
+
# For individuals and businesses earning **under $1M per year**, this software is licensed under the **MIT License**
|
| 4 |
+
# Businesses or organizations with **annual revenue of $1,000,000 or more** must obtain permission to use this software commercially.
|
| 5 |
+
|
| 6 |
+
from flask import request, jsonify
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import flask
|
| 10 |
+
|
| 11 |
+
from utils.generate_face_shapes import generate_facial_data_from_bytes
|
| 12 |
+
from utils.model.model import load_model
|
| 13 |
+
from utils.config import config
|
| 14 |
+
|
| 15 |
+
app = flask.Flask(__name__)
|
| 16 |
+
|
| 17 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 18 |
+
print("Activated device:", device)
|
| 19 |
+
|
| 20 |
+
model_path = 'utils/model/model.pth'
|
| 21 |
+
blendshape_model = load_model(model_path, config, device)
|
| 22 |
+
|
| 23 |
+
@app.route('/audio_to_blendshapes', methods=['POST'])
def audio_to_blendshapes_route():
    """POST raw audio bytes in the request body; respond with JSON containing
    a 'blendshapes' list of per-frame coefficient rows."""
    audio_bytes = request.data
    generated_facial_data = generate_facial_data_from_bytes(audio_bytes, blendshape_model, device, config)
    # jsonify cannot serialize numpy arrays; convert to a plain Python list first.
    generated_facial_data_list = generated_facial_data.tolist() if isinstance(generated_facial_data, np.ndarray) else generated_facial_data

    return jsonify({'blendshapes': generated_facial_data_list})
| 30 |
+
|
| 31 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Flask
|
| 2 |
+
numpy
|
| 3 |
+
torch
|
| 4 |
+
gunicorn
|
| 5 |
+
|
| 6 |
+
librosa
|
| 7 |
+
scipy
|
utils/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (403 Bytes). View file
|
|
|
utils/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (392 Bytes). View file
|
|
|
utils/__pycache__/generate_face_shapes.cpython-311.pyc
ADDED
|
Binary file (882 Bytes). View file
|
|
|
utils/__pycache__/generate_face_shapes.cpython-312.pyc
ADDED
|
Binary file (798 Bytes). View file
|
|
|
utils/audio/extraction/__pycache__/extract_features.cpython-311.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
utils/audio/extraction/__pycache__/extract_features.cpython-312.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
utils/audio/extraction/extract_features.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This software is licensed under a **dual-license model**
|
| 2 |
+
# For individuals and businesses earning **under $1M per year**, this software is licensed under the **MIT License**
|
| 3 |
+
# Businesses or organizations with **annual revenue of $1,000,000 or more** must obtain permission to use this software commercially.
|
| 4 |
+
|
| 5 |
+
# extract_features.py
|
| 6 |
+
import io
|
| 7 |
+
import librosa
|
| 8 |
+
import numpy as np
|
| 9 |
+
import scipy.signal
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def extract_audio_features(audio_input, sr=88200, from_bytes=False):
    """Load audio (file path or encoded bytes) and build framed features.

    Returns (features, samples); (None, None) when the clip is too short.
    Falls back to raw-PCM decoding when WAV parsing fails.
    """
    try:
        loader = load_audio_from_bytes if from_bytes else load_and_preprocess_audio
        y, sr = loader(audio_input, sr)
    except Exception as e:
        print(f"Loading as WAV failed: {e}\nFalling back to PCM loading.")
        y = load_pcm_audio_from_bytes(audio_input)

    frame_length = int(0.01667 * sr)  # ~1/60 s per frame (60 fps animation)
    hop_length = frame_length // 2    # 2x overlap for smoother transitions
    min_frames = 9                    # minimum frames needed for delta features

    num_frames = (len(y) - frame_length) // hop_length + 1
    if num_frames < min_frames:
        print(f"Audio file is too short: {num_frames} frames, required: {min_frames} frames")
        return None, None

    return extract_and_combine_features(y, sr, frame_length, hop_length), y
|
| 35 |
+
|
| 36 |
+
def extract_and_combine_features(y, sr, frame_length, hop_length, include_autocorr=True):
    """Stack per-frame MFCC features and, optionally, autocorrelation features."""
    feature_blocks = [extract_mfcc_features(y, sr, frame_length, hop_length)]

    if include_autocorr:
        feature_blocks.append(
            extract_autocorrelation_features(y, sr, frame_length, hop_length)
        )

    return np.hstack(feature_blocks)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def extract_mfcc_features(y, sr, frame_length, hop_length, num_mfcc=23):
    """Compute MFCC(+delta) features, halve the frame rate, return frames as rows."""
    raw = extract_overlapping_mfcc(y, sr, num_mfcc, frame_length, hop_length)
    return reduce_features(raw).T
|
| 57 |
+
|
| 58 |
+
def cepstral_mean_variance_normalization(mfcc):
    """Apply CMVN: zero-mean, unit-variance per coefficient across the time axis.

    A small epsilon keeps the division safe for constant (zero-variance) rows.
    """
    centered = mfcc - np.mean(mfcc, axis=1, keepdims=True)
    return centered / (np.std(mfcc, axis=1, keepdims=True) + 1e-10)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def extract_overlapping_mfcc(chunk, sr, num_mfcc, frame_length, hop_length, include_deltas=True, include_cepstral=True, threshold=1e-5):
    """Compute MFCCs for `chunk`, optionally CMVN-normalized and stacked with
    first- and second-order deltas.

    Returns an array of shape (num_mfcc, frames), or (3*num_mfcc, frames)
    when `include_deltas` is True.

    NOTE(review): `threshold` is accepted but never used in this body —
    confirm whether it can be removed or was meant to gate something.
    """
    mfcc = librosa.feature.mfcc(y=chunk, sr=sr, n_mfcc=num_mfcc, n_fft=frame_length, hop_length=hop_length)
    if include_cepstral:
        # Per-coefficient mean/variance normalization across time.
        mfcc = cepstral_mean_variance_normalization(mfcc)

    if include_deltas:
        delta_mfcc = librosa.feature.delta(mfcc)
        delta2_mfcc = librosa.feature.delta(mfcc, order=2)
        combined_mfcc = np.vstack([mfcc, delta_mfcc, delta2_mfcc]) # Stack original MFCCs with deltas
        return combined_mfcc
    else:
        return mfcc
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def reduce_features(features):
    """Halve the frame rate by averaging adjacent frame pairs (columns).

    When the frame count is odd, the final unpaired frame is appended as-is,
    so the output has ceil(n/2) columns.
    """
    n_cols = features.shape[1]
    even_part = features[:, : (n_cols // 2) * 2]
    halved = even_part.reshape(features.shape[0], -1, 2).mean(axis=2)

    if n_cols % 2:
        return np.hstack((halved, features[:, -1].reshape(-1, 1)))
    return halved
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def extract_overlapping_autocorr(y, sr, frame_length, hop_length, num_autocorr_coeff=187, pad_signal=True, padding_mode="reflect", trim_padded=False):
    """Compute per-frame normalized autocorrelation coefficients.

    The signal is optionally center-padded, framed, mean-removed and Hann
    windowed; each frame's autocorrelation is normalized by its zero-lag
    energy. Returns an array of shape (num_autocorr_coeff, frames) — the
    zero-lag coefficient itself is dropped as redundant (always 1 after
    normalization).
    """
    if pad_signal:
        # Center-pad by half a frame so edge samples get full frame coverage.
        pad = frame_length // 2
        y_padded = np.pad(y, pad_width=pad, mode=padding_mode)
    else:
        y_padded = y

    frames = librosa.util.frame(y_padded, frame_length=frame_length, hop_length=hop_length)
    if pad_signal and trim_padded:
        # Keep only frames that fall entirely within the unpadded signal span.
        num_frames = frames.shape[1]
        start_indices = np.arange(num_frames) * hop_length
        valid_idx = np.where((start_indices >= pad) & (start_indices + frame_length <= len(y) + pad))[0]
        frames = frames[:, valid_idx]

    # Remove per-frame DC offset, then apply a Hann window before correlating.
    frames = frames - np.mean(frames, axis=0, keepdims=True)
    hann_window = np.hanning(frame_length)
    windowed_frames = frames * hann_window[:, np.newaxis]

    autocorr_list = []
    for frame in windowed_frames.T:
        full_corr = np.correlate(frame, frame, mode='full')
        mid = frame_length - 1  # Zero-lag index.
        # Extract `num_autocorr_coeff + 1` to include the first column initially
        wanted = full_corr[mid: mid + num_autocorr_coeff + 1]
        # Normalize by the zero-lag (energy) if nonzero.
        if wanted[0] != 0:
            wanted = wanted / wanted[0]
        autocorr_list.append(wanted)

    # Convert list to array and transpose so that shape is (num_autocorr_coeff + 1, num_valid_frames)
    autocorr_features = np.array(autocorr_list).T
    # Remove the first coefficient to avoid redundancy
    autocorr_features = autocorr_features[1:, :]

    # Repair near-silent first/last frames by copying their neighbours.
    autocorr_features = fix_edge_frames_autocorr(autocorr_features)

    return autocorr_features
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def fix_edge_frames_autocorr(autocorr_features, zero_threshold=1e-7):
    """If the first or last frame is near all-zero, replicate from adjacent frames.

    Operates in place on `autocorr_features` (shape: coeffs x frames) and
    returns it. Fix: guard against inputs with fewer than two frames, which
    previously raised IndexError when a lone frame was near-silent.
    """
    # Need at least two frames to borrow from a neighbour.
    if autocorr_features.shape[1] < 2:
        return autocorr_features
    # Check first frame energy
    if np.all(np.abs(autocorr_features[:, 0]) < zero_threshold):
        autocorr_features[:, 0] = autocorr_features[:, 1]
    # Check last frame energy
    if np.all(np.abs(autocorr_features[:, -1]) < zero_threshold):
        autocorr_features[:, -1] = autocorr_features[:, -2]
    return autocorr_features
|
| 141 |
+
|
| 142 |
+
def extract_autocorrelation_features(
    y, sr, frame_length, hop_length, include_deltas=False
):
    """
    Extract autocorrelation features, optionally with deltas/delta-deltas,
    then align with the MFCC frame count, reduce, and handle first/last frames.
    """
    base = extract_overlapping_autocorr(y, sr, frame_length, hop_length)

    if include_deltas:
        base = compute_autocorr_with_deltas(base)

    return reduce_features(base).T
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def compute_autocorr_with_deltas(autocorr_base):
    """Stack autocorrelation features with their first- and second-order deltas
    along the coefficient axis (triples the row count)."""
    delta_ac = librosa.feature.delta(autocorr_base)
    delta2_ac = librosa.feature.delta(autocorr_base, order=2)
    combined_autocorr = np.vstack([autocorr_base, delta_ac, delta2_ac])
    return combined_autocorr
|
| 166 |
+
|
| 167 |
+
def load_and_preprocess_audio(audio_path, sr=88200):
    """Load an audio file, force an 88200 Hz rate, and peak-normalize to [-1, 1]."""
    y, sr = load_audio(audio_path, sr)
    if sr != 88200:
        y = librosa.resample(y, orig_sr=sr, target_sr=88200)
        sr = 88200

    # Peak normalization; skipped for all-silent input to avoid divide-by-zero.
    peak = np.max(np.abs(y))
    if peak > 0:
        y = y / peak

    return y, sr
|
| 178 |
+
|
| 179 |
+
def load_audio(audio_path, sr=88200):
    """Load an audio file at the requested sample rate via librosa."""
    y, sr = librosa.load(audio_path, sr=sr)
    print(f"Loaded audio file '{audio_path}' with sample rate {sr}")
    return y, sr
|
| 183 |
+
|
| 184 |
+
def load_audio_from_bytes(audio_bytes, sr=88200):
    """Decode encoded audio bytes (e.g. WAV) and peak-normalize to [-1, 1]."""
    audio_file = io.BytesIO(audio_bytes)
    y, sr = librosa.load(audio_file, sr=sr)

    # Peak normalization; skipped for all-silent input to avoid divide-by-zero.
    max_val = np.max(np.abs(y))
    if max_val > 0:
        y = y / max_val

    return y, sr
|
| 193 |
+
|
| 194 |
+
def load_audio_file_from_memory(audio_bytes, sr=88200):
    """Load audio from memory bytes.

    Same as load_audio_from_bytes: decode via librosa at the requested rate
    and peak-normalize to [-1, 1].
    """
    y, sr = librosa.load(io.BytesIO(audio_bytes), sr=sr)
    print(f"Loaded audio data with sample rate {sr}")

    # Peak normalization; skipped for all-silent input.
    max_val = np.max(np.abs(y))
    if max_val > 0:
        y = y / max_val

    return y, sr
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def load_pcm_audio_from_bytes(audio_bytes, sr=22050, channels=1, sample_width=2):
    """
    Decode raw little-endian 16-bit PCM bytes to float32 in [-1, 1] and
    resample to 88200 Hz. Multi-channel data is resampled per channel.
    """
    # Only 16-bit samples are supported.
    if sample_width != 2:
        raise ValueError("Unsupported sample width")

    samples = np.frombuffer(audio_bytes, dtype=np.int16)

    # Interleaved multi-channel data becomes one column per channel.
    if channels > 1:
        samples = samples.reshape(-1, channels)

    # Normalize to [-1, 1] using the int16 full-scale magnitude.
    y = samples.astype(np.float32) / 32768.0

    target_sr = 88200
    if sr != target_sr:
        resampled_len = int(len(y) * target_sr / sr)
        if channels > 1:
            resampled = np.zeros((resampled_len, channels), dtype=np.float32)
            for ch in range(channels):
                resampled[:, ch] = scipy.signal.resample(y[:, ch], resampled_len)
        else:
            resampled = scipy.signal.resample(y, resampled_len)
        y = resampled

    return y
|
utils/audio/processing/__pycache__/audio_processing.cpython-311.pyc
ADDED
|
Binary file (7.74 kB). View file
|
|
|
utils/audio/processing/__pycache__/audio_processing.cpython-312.pyc
ADDED
|
Binary file (6.85 kB). View file
|
|
|
utils/audio/processing/audio_processing.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This software is licensed under a **dual-license model**
|
| 2 |
+
# For individuals and businesses earning **under $1M per year**, this software is licensed under the **MIT License**
|
| 3 |
+
# Businesses or organizations with **annual revenue of $1,000,000 or more** must obtain permission to use this software commercially.
|
| 4 |
+
|
| 5 |
+
# audio_processing.py
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch.cuda.amp import autocast
|
| 10 |
+
|
| 11 |
+
def decode_audio_chunk(audio_chunk, model, device, config):
    """Run one feature chunk through the model's encoder/decoder and return
    the decoded frames as a numpy array (batch dimension removed).

    Precision is selected from config['use_half_precision'] (default True).
    NOTE(review): float16 tensors plus torch.cuda.amp.autocast assume a CUDA
    device — confirm behavior when `device` is CPU.
    """
    # Use precision based on config
    use_half_precision = config.get("use_half_precision", True)

    # Force float16 if half precision is desired; else float32
    dtype = torch.float16 if use_half_precision else torch.float32

    # Convert audio chunk directly to the desired precision; unsqueeze adds
    # the batch dimension the model expects.
    src_tensor = torch.tensor(audio_chunk, dtype=dtype).unsqueeze(0).to(device)

    with torch.no_grad():
        if use_half_precision:

            with autocast(dtype=torch.float16):
                encoder_outputs = model.encoder(src_tensor)
                output_sequence = model.decoder(encoder_outputs)
        else:
            encoder_outputs = model.encoder(src_tensor)
            output_sequence = model.decoder(encoder_outputs)

    # Convert output tensor back to numpy array
    decoded_outputs = output_sequence.squeeze(0).cpu().numpy()
    return decoded_outputs
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def concatenate_outputs(all_decoded_outputs, num_frames):
    """Join decoded chunks along the frame axis and trim to `num_frames`."""
    return np.concatenate(all_decoded_outputs, axis=0)[:num_frames]
|
| 40 |
+
|
| 41 |
+
def ensure_2d(final_decoded_outputs):
    """Collapse a 3-D (chunks, frames, features) array to 2-D (frames, features).

    Arrays that are already 2-D (or any other rank) pass through untouched.
    """
    if final_decoded_outputs.ndim != 3:
        return final_decoded_outputs
    return final_decoded_outputs.reshape(-1, final_decoded_outputs.shape[-1])
|
| 45 |
+
|
| 46 |
+
def pad_audio_chunk(audio_chunk, frame_length, num_features, pad_mode='replicate'):
    """
    Pad `audio_chunk` (shape: frames x num_features) up to `frame_length` frames.

    pad_mode:
        'reflect'   — mirror the tail of the chunk.
        'replicate' — repeat the final frame.

    Chunks already at or above `frame_length` frames are returned unchanged.
    Raises ValueError for any other pad_mode.
    """
    deficit = frame_length - audio_chunk.shape[0]
    if deficit <= 0:
        return audio_chunk

    if pad_mode == 'reflect':
        # Mirror-pad along the frame axis, then append the reflected tail.
        reflected = np.pad(
            audio_chunk,
            pad_width=((0, deficit), (0, 0)),
            mode='reflect'
        )
        return np.vstack((audio_chunk, reflected[-deficit:, :num_features]))

    if pad_mode == 'replicate':
        # Repeat the final frame `deficit` times.
        tail = np.tile(audio_chunk[-1:], (deficit, 1))
        return np.vstack((audio_chunk, tail))

    raise ValueError(f"Unsupported pad_mode: {pad_mode}. Choose 'reflect' or 'replicate'.")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def blend_chunks(chunk1, chunk2, overlap):
    """Crossfade the tail of chunk1 into the head of chunk2 over `overlap` frames.

    The effective overlap is clamped to the length of the shorter chunk;
    with no overlap the chunks are simply stacked.
    """
    fade_len = min(overlap, len(chunk1), len(chunk2))
    if fade_len == 0:
        return np.vstack((chunk1, chunk2))

    merged = np.copy(chunk1)
    # Linear fade weights 0, 1/fade_len, ..., (fade_len-1)/fade_len,
    # broadcast across the trailing feature dimensions.
    ramp = np.arange(fade_len) / fade_len
    ramp = ramp.reshape((fade_len,) + (1,) * (chunk1.ndim - 1))
    merged[-fade_len:] = (1 - ramp) * chunk1[-fade_len:] + ramp * chunk2[:fade_len]

    return np.vstack((merged, chunk2[fade_len:]))
|
| 98 |
+
|
| 99 |
+
def process_audio_features(audio_features, model, device, config):
    """Decode framed audio features into blendshape frames.

    Features are processed in overlapping chunks of config['frame_size']
    frames; consecutive chunk outputs are crossfaded over the overlap region,
    the result is trimmed to the input frame count, scaled, eased in at the
    start, and finally has fixed columns zeroed.
    """
    # Configuration settings
    frame_length = config['frame_size'] # Number of frames per chunk (e.g., 64)
    overlap = config.get('overlap', 32) # Number of overlapping frames between chunks
    num_features = audio_features.shape[1]
    num_frames = audio_features.shape[0]
    all_decoded_outputs = []

    # Set model to evaluation mode
    model.eval()

    # Process chunks with the specified overlap
    start_idx = 0
    while start_idx < num_frames:
        end_idx = min(start_idx + frame_length, num_frames)

        # Select and pad chunk if needed
        audio_chunk = audio_features[start_idx:end_idx]
        audio_chunk = pad_audio_chunk(audio_chunk, frame_length, num_features)

        # Pass config to dynamically choose precision
        decoded_outputs = decode_audio_chunk(audio_chunk, model, device, config)
        # Drop any frames that came from padding.
        decoded_outputs = decoded_outputs[:end_idx - start_idx]

        # Blend with the last chunk if it exists
        if all_decoded_outputs:
            last_chunk = all_decoded_outputs.pop()
            blended_chunk = blend_chunks(last_chunk, decoded_outputs, overlap)
            all_decoded_outputs.append(blended_chunk)
        else:
            all_decoded_outputs.append(decoded_outputs)

        # Move start index forward by (frame_length - overlap)
        start_idx += frame_length - overlap

    # Process any remaining frames to ensure total frame count matches input
    current_length = sum(len(chunk) for chunk in all_decoded_outputs)
    if current_length < num_frames:
        remaining_frames = num_frames - current_length
        final_chunk_start = num_frames - remaining_frames
        audio_chunk = audio_features[final_chunk_start:num_frames]
        audio_chunk = pad_audio_chunk(audio_chunk, frame_length, num_features)
        decoded_outputs = decode_audio_chunk(audio_chunk, model, device, config)
        all_decoded_outputs.append(decoded_outputs[:remaining_frames])

    # Concatenate all chunks and trim to the original frame count
    final_decoded_outputs = np.concatenate(all_decoded_outputs, axis=0)[:num_frames]

    # Normalize or apply any post-processing
    final_decoded_outputs = ensure_2d(final_decoded_outputs)
    final_decoded_outputs[:, :61] /= 100 # Normalize specific columns

    # Easing effect for smooth start: fades in the first int(0.1 * 60) = 6
    # frames (~0.1 s at 60 fps), clamped to the available frame count.
    ease_duration_frames = min(int(0.1 * 60), final_decoded_outputs.shape[0])
    easing_factors = np.linspace(0, 1, ease_duration_frames)[:, None]
    final_decoded_outputs[:ease_duration_frames] *= easing_factors

    # Zero out unnecessary columns (optional post-processing)
    final_decoded_outputs = zero_columns(final_decoded_outputs)

    return final_decoded_outputs
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def zero_columns(data):
    """Return a copy of `data` with a fixed set of blendshape columns zeroed.

    The input array is left untouched.
    """
    silenced = [0, 1, 2, 3, 4, 7, 8, 9, 10, 11,
                51, 52, 53, 54, 55, 56, 57, 58, 59, 60]
    result = data.copy()
    result[:, silenced] = 0
    return result
|
utils/config.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Shared model/audio hyper-parameters. These values must match the
# architecture of the checkpoint loaded from utils/model/model.pth
# (see utils/model/model.py and the audio pipeline in utils/audio/).
config = {
    'sr': 88200,  # audio sample rate used throughout the feature pipeline
    'frame_rate': 60,  # animation frames per second
    'hidden_dim': 1024,  # model hidden size
    'n_layers': 8,  # number of model layers
    'num_heads': 16,  # attention heads
    'dropout': 0.0,
    'output_dim': 68, # if you trained your own, this should also be 61
    'input_dim': 256,  # feature vector size per frame
    'frame_size': 128,  # frames per inference chunk (see process_audio_features)
    'use_half_precision': False  # float32 inference (see decode_audio_chunk)
}
|
utils/generate_face_shapes.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This software is licensed under a **dual-license model**
|
| 2 |
+
# For individuals and businesses earning **under $1M per year**, this software is licensed under the **MIT License**
|
| 3 |
+
# Businesses or organizations with **annual revenue of $1,000,000 or more** must obtain permission to use this software commercially.
|
| 4 |
+
|
| 5 |
+
# generate_face_shapes.py
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from utils.audio.extraction.extract_features import extract_audio_features
|
| 10 |
+
from utils.audio.processing.audio_processing import process_audio_features
|
| 11 |
+
|
| 12 |
+
def generate_facial_data_from_bytes(audio_bytes, model, device, config):
    """Run the full audio -> blendshape pipeline on in-memory audio bytes.

    Returns a 2-D array of per-frame blendshape coefficients, or an empty
    list when the audio is too short for feature extraction.
    """
    audio_features, y = extract_audio_features(audio_bytes, from_bytes=True)

    if audio_features is None or y is None:
        # Bug fix: this previously returned the 2-tuple `[], np.array([])`
        # while the success path returns a single value; callers such as the
        # Flask route treat the result as one object, and a bare ndarray is
        # not JSON-serializable. Return a single empty list instead.
        return []

    return process_audio_features(audio_features, model, device, config)
|
| 22 |
+
|
utils/model/__pycache__/model.cpython-311.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
utils/model/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
utils/model/model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dcf989405316b25c3d2cdfe389e41b50326284bd520126979eb602978ba842d3
|
| 3 |
+
size 942037770
|
utils/model/model.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This software is licensed under a **dual-license model**
|
| 2 |
+
# For individuals and businesses earning **under $1M per year**, this software is licensed under the **MIT License**
|
| 3 |
+
# Businesses or organizations with **annual revenue of $1,000,000 or more** must obtain permission to use this software commercially.
|
| 4 |
+
|
| 5 |
+
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
def load_model(model_path, config, device):
    """Build the Seq2Seq model, load weights from model_path, and return it in eval mode.

    Side effect: if half precision was requested but CUDA/cuDNN is unavailable,
    config['use_half_precision'] is overwritten with False so the rest of the
    pipeline sees the effective setting.
    """
    torch_device = torch.device(device)

    half = config.get('use_half_precision', True)
    cuda_ok = (torch_device.type == 'cuda'
               and torch.cuda.is_available()
               and torch.backends.cudnn.enabled)
    if half and not cuda_ok:
        print("⚠ Half-precision requested but CUDA or cuDNN not available. Falling back to full precision.")
        half = False
        config['use_half_precision'] = False  # keep the config in sync with the fallback

    net = Seq2Seq(
        Encoder(config['input_dim'], config['hidden_dim'], config['n_layers'], config['num_heads']),
        Decoder(config['output_dim'], config['hidden_dim'], config['n_layers'], config['num_heads']),
        torch_device,
    ).to(torch_device)

    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    net.load_state_dict(torch.load(model_path, map_location=torch_device), strict=True)

    if half and torch_device.type == 'cuda':
        net = net.to(torch.float16)
        print("⚡ Model converted to float16 (half-precision).")
    else:
        print("🚫 Half-precision not applied (CPU or unsupported GPU or False set in config).")

    net.eval()
    return net
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# -------------------------------------------------------------------------------------------
|
| 48 |
+
# Seq2Seq Model
|
| 49 |
+
# -------------------------------------------------------------------------------------------
|
| 50 |
+
class Seq2Seq(nn.Module):
    """Thin wrapper chaining an encoder and a decoder.

    The decoder here consumes the encoder output directly (there is no
    separate target sequence), so forward is a simple two-stage pipeline.
    """

    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src):
        # Encode, then decode the encoded representation.
        return self.decoder(self.encoder(src))
| 61 |
+
|
| 62 |
+
# -------------------------------------------------------------------------------------------
|
| 63 |
+
# Rotary Positional Embedding (RoPE) for Local Attention
|
| 64 |
+
# -------------------------------------------------------------------------------------------
|
| 65 |
+
def apply_rope_qk(q, k, use_local_positional_encoding=True):
    """Apply Rotary Positional Embedding (RoPE) to query/key tensors.

    q, k: (batch, num_heads, seq_len, head_dim); head_dim must be even.
    Returns the rotated (q, k), or the inputs unchanged when disabled.
    """
    if not use_local_positional_encoding:
        return q, k  # RoPE disabled: leave q, k untouched

    batch_size, num_heads, seq_len, head_dim = q.size()
    assert head_dim % 2 == 0, "head_dim must be even for RoPE"

    position = torch.arange(seq_len, dtype=torch.float, device=q.device).unsqueeze(1)  # (seq_len, 1)
    dim_indices = torch.arange(0, head_dim, 2, dtype=torch.float, device=q.device)  # (head_dim // 2,)
    # FIX: use a plain Python constant instead of torch.log(torch.tensor(10000.0)),
    # which allocated a scalar tensor on the CPU default device on every call.
    div_term = torch.exp(dim_indices * (-math.log(10000.0) / head_dim))

    angle = position * div_term  # (seq_len, head_dim // 2)
    # FIX: cast sin/cos to q's dtype. They were unconditionally float32 before,
    # which silently promoted half-precision q/k to float32 and broke the fp16
    # path (scaled_dot_product_attention requires q, k, v to share a dtype).
    sin = torch.sin(angle).to(q.dtype).unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, head_dim // 2)
    cos = torch.cos(angle).to(q.dtype).unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, head_dim // 2)

    def rope_transform(x):
        # Rotate each (even, odd) channel pair by the position-dependent angle.
        x1, x2 = x[..., ::2], x[..., 1::2]
        x_rope_even = x1 * cos - x2 * sin
        x_rope_odd = x1 * sin + x2 * cos
        return torch.stack([x_rope_even, x_rope_odd], dim=-1).flatten(-2)

    q = rope_transform(q)
    k = rope_transform(k)
    return q, k
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# -------------------------------------------------------------------------------------------
|
| 92 |
+
# Multi-Head Attention with RoPE
|
| 93 |
+
# -------------------------------------------------------------------------------------------
|
| 94 |
+
class MultiHeadAttention(nn.Module):
    """Multi-head attention with RoPE applied to Q/K.

    Uses torch's fused scaled_dot_product_attention when available
    (PyTorch >= 2.0) and falls back to an explicit softmax otherwise.
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        assert hidden_dim % num_heads == 0, "Hidden dimension must be divisible by the number of heads"
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scaling = self.head_dim ** -0.5

        self.q_linear = nn.Linear(hidden_dim, hidden_dim)
        self.k_linear = nn.Linear(hidden_dim, hidden_dim)
        self.v_linear = nn.Linear(hidden_dim, hidden_dim)
        self.out_linear = nn.Linear(hidden_dim, hidden_dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.dropout = dropout

        # The fused kernel only exists on PyTorch >= 2.0.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: Flash Attention requires PyTorch >= 2.0")

    def _split_heads(self, x, batch_size):
        # (B, L, H*D) -> (B, H, L, D)
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        q = self._split_heads(self.q_linear(query), batch_size)
        k = self._split_heads(self.k_linear(key), batch_size)
        v = self._split_heads(self.v_linear(value), batch_size)

        # Rotary positional encoding on Q/K only (if enabled inside the helper).
        q, k = apply_rope_qk(q, k)

        if self.flash:
            # SDPA applies its own 1/sqrt(d) scaling and (optional) dropout.
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0)
            attn_weights = None
        else:
            scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling
            if mask is not None:
                scores = scores.masked_fill(mask == 0, float('-inf'))
            attn_weights = self.attn_dropout(F.softmax(scores, dim=-1))
            attn_output = torch.matmul(attn_weights, v)

        # (B, H, L, D) -> (B, L, H*D)
        merged = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        return self.resid_dropout(self.out_linear(merged)), attn_weights
| 147 |
+
|
| 148 |
+
# -------------------------------------------------------------------------------------------
|
| 149 |
+
# Feed-Forward Network
|
| 150 |
+
# -------------------------------------------------------------------------------------------
|
| 151 |
+
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, hidden_dim, dim_feedforward=2048, dropout=0.0):
        super(FeedForwardNetwork, self).__init__()
        self.linear1 = nn.Linear(hidden_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, hidden_dim)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.linear1(x)))
        return self.linear2(hidden)
| 164 |
+
|
| 165 |
+
# -------------------------------------------------------------------------------------------
|
| 166 |
+
# Custom Transformer Encoder/Decoder
|
| 167 |
+
# -------------------------------------------------------------------------------------------
|
| 168 |
+
class CustomTransformerEncoderLayer(nn.Module):
    """Post-norm transformer encoder layer: self-attention + FFN, each with a residual."""

    def __init__(self, hidden_dim, num_heads, dropout=0.0):
        super(CustomTransformerEncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(hidden_dim, num_heads, dropout)
        self.ffn = FeedForwardNetwork(hidden_dim, 4 * hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, mask=None):
        # Self-attention sublayer: residual add, then LayerNorm (post-norm).
        attn_out, _ = self.self_attn(src, src, src, mask)
        src = self.norm1(src + self.dropout1(attn_out))

        # Feed-forward sublayer: residual add, then LayerNorm.
        src = self.norm2(src + self.dropout2(self.ffn(src)))
        return src
| 187 |
+
|
| 188 |
+
class CustomTransformerDecoderLayer(nn.Module):
    """Post-norm decoder layer: self-attention, cross-attention over memory, then FFN."""

    def __init__(self, hidden_dim, num_heads, dropout=0.0):
        super(CustomTransformerDecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(hidden_dim, num_heads, dropout)
        self.multihead_attn = MultiHeadAttention(hidden_dim, num_heads, dropout)
        self.ffn = FeedForwardNetwork(hidden_dim, 4 * hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # 1) Self-attention over the target sequence.
        sa_out, _ = self.self_attn(tgt, tgt, tgt, tgt_mask)
        tgt = self.norm1(tgt + self.dropout1(sa_out))

        # 2) Cross-attention: queries from tgt, keys/values from encoder memory.
        ca_out, _ = self.multihead_attn(tgt, memory, memory, memory_mask)
        tgt = self.norm2(tgt + self.dropout2(ca_out))

        # 3) Position-wise feed-forward.
        tgt = self.norm3(tgt + self.dropout3(self.ffn(tgt)))
        return tgt
| 214 |
+
|
| 215 |
+
# -------------------------------------------------------------------------------------------
|
| 216 |
+
# Encoder
|
| 217 |
+
# -------------------------------------------------------------------------------------------
|
| 218 |
+
class Encoder(nn.Module):
    """Stack of encoder layers over a linear input embedding.

    No global positional encoding is added here; RoPE inside the attention
    layers supplies the position information.
    """

    def __init__(self, input_dim, hidden_dim, n_layers, num_heads, dropout=0.0, use_norm=True):
        super(Encoder, self).__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.transformer_encoder = nn.ModuleList(
            CustomTransformerEncoderLayer(hidden_dim, num_heads, dropout)
            for _ in range(n_layers)
        )
        self.layer_norm = nn.LayerNorm(hidden_dim) if use_norm else None

    def forward(self, x):
        x = self.embedding(x)
        for encoder_layer in self.transformer_encoder:
            x = encoder_layer(x)
        # Optional final normalization.
        return self.layer_norm(x) if self.layer_norm else x
| 236 |
+
|
| 237 |
+
# -------------------------------------------------------------------------------------------
|
| 238 |
+
# Decoder
|
| 239 |
+
# -------------------------------------------------------------------------------------------
|
| 240 |
+
class Decoder(nn.Module):
    """Stack of decoder layers driven entirely by the encoder output.

    forward takes only encoder_outputs: the same tensor serves as both the
    decoder input and the cross-attention memory in every layer.
    """

    def __init__(self, output_dim, hidden_dim, n_layers, num_heads, dropout=0.0, use_norm=True):
        super(Decoder, self).__init__()
        self.transformer_decoder = nn.ModuleList(
            CustomTransformerDecoderLayer(hidden_dim, num_heads, dropout)
            for _ in range(n_layers)
        )
        self.fc_output = nn.Linear(hidden_dim, output_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim) if use_norm else None

    def forward(self, encoder_outputs):
        x = encoder_outputs
        for decoder_layer in self.transformer_decoder:
            x = decoder_layer(x, encoder_outputs)
        if self.layer_norm:
            x = self.layer_norm(x)
        # Project to the output (blendshape) dimension.
        return self.fc_output(x)
| 256 |
+
|