|
|
import torch |
|
|
import torch.nn as nn |
|
|
import faiss |
|
|
import math |
|
|
|
|
|
from model_args import ModelArgs |
|
|
from transformer import Transformer |
|
|
from tranny import FAISSRetriever |
|
|
|
|
|
|
|
|
class CombinedMultiModalTransformer(nn.Module):
    """Multi-task model routing several modalities through one shared Transformer.

    A single Transformer backbone is fed by task-specific front-ends:
    a strided-conv audio encoder, a pooled image projector, summed MIDI-style
    embeddings for music, a retrieval-augmented path for text generation, and
    a small MLP head for anomaly scoring.

    Args:
        args: Model hyperparameters; ``args.dim`` is the shared model width.
        knowledge_base: A FAISS index wrapped by ``FAISSRetriever`` for
            retrieval-augmented text generation.
    """

    def __init__(self, args: ModelArgs, knowledge_base: faiss.Index):
        super().__init__()
        self.args = args
        self.transformer = Transformer(args)

        # Audio front-end: two strided 1-D convolutions downsample the time
        # axis 4x and project to the model dimension.
        # BUG FIX: PyTorch rejects padding='same' when stride > 1
        # (ValueError at construction time), so the explicit equivalent
        # padding of kernel_size // 2 == 5 is used instead.
        # NOTE(review): input is assumed to be (batch, 256, time) — the 256
        # input channels are hard-coded; confirm against the feature extractor.
        self.audio_encoder = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=11, stride=2, padding=5),
            nn.ReLU(),
            nn.Conv1d(256, args.dim, kernel_size=11, stride=2, padding=5),
            nn.ReLU(),
        )

        # Image front-end: global average pool to (1, 1), flatten, then a
        # linear projection from 2048 channels to the model dimension.
        # NOTE(review): the 2048 input width suggests ResNet-50-style
        # features are expected upstream — confirm with callers.
        self.image_encoder = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(2048, args.dim),
        )

        # Music front-end: MIDI-style token embeddings that are summed in
        # forward(). Sizes follow MIDI conventions (128 pitches/velocities);
        # durations are bucketed into 32 bins.
        self.pitch_embedding = nn.Embedding(128, args.dim)
        self.duration_embedding = nn.Embedding(32, args.dim)
        self.velocity_embedding = nn.Embedding(128, args.dim)

        # Anomaly head: per-position score in (0, 1) via a 2-layer MLP.
        self.anomaly_detector = nn.Sequential(
            nn.Linear(args.dim, args.dim),
            nn.ReLU(),
            nn.Linear(args.dim, 1),
            nn.Sigmoid(),
        )

        # Retrieval components for text generation: embeddings are mapped to
        # a query space, then nearest neighbours are fetched from FAISS.
        self.knowledge_base = FAISSRetriever(knowledge_base)
        self.query_encoder = nn.Sequential(
            nn.Linear(args.dim, args.dim),
            nn.ReLU(),
            nn.Linear(args.dim, args.dim),
        )

    def forward(self, inputs, task):
        """Run the front-end selected by ``task`` and then the Transformer.

        Args:
            inputs: Task-dependent input. For ``'music_generation'`` a
                ``(pitch, duration, velocity)`` tuple of LongTensors; for the
                other tasks a single tensor in the shape the corresponding
                front-end expects.
            task: One of ``'text_generation'``, ``'speech_recognition'``,
                ``'image_captioning'``, ``'music_generation'``,
                ``'anomaly_detection'``.

        Returns:
            Transformer logits for the generation/recognition tasks, or a
            tensor of anomaly scores in (0, 1) for ``'anomaly_detection'``.

        Raises:
            ValueError: If ``task`` is not one of the supported names.
        """
        if task == 'text_generation':
            # Encode the (embedded) input into a retrieval query and fetch
            # the top-5 neighbours from the knowledge base.
            query_embedding = self.query_encoder(self.transformer.embed(inputs))
            retrieved_docs = self.knowledge_base.search(query_embedding, k=5)
            # NOTE(review): this assumes FAISSRetriever.search returns a
            # tensor concatenable with `inputs` along dim=1 (i.e. same dtype
            # and per-position representation as the raw inputs, not
            # embeddings) — verify against FAISSRetriever's contract.
            inputs = torch.cat([retrieved_docs, inputs], dim=1)
            logits = self.transformer(inputs)
            return logits

        elif task == 'speech_recognition':
            # (batch, 256, time) audio features -> downsampled model-dim
            # features, then the shared backbone.
            x = self.audio_encoder(inputs)
            logits = self.transformer(x)
            return logits

        elif task == 'image_captioning':
            # Pooled image features projected to model dim, then the backbone.
            image_features = self.image_encoder(inputs)
            logits = self.transformer(image_features)
            return logits

        elif task == 'music_generation':
            # Sum the three note-attribute embeddings into one token stream.
            pitch, duration, velocity = inputs
            x = self.pitch_embedding(pitch) + self.duration_embedding(duration) + self.velocity_embedding(velocity)
            logits = self.transformer(x)
            return logits

        elif task == 'anomaly_detection':
            # Score embedded inputs directly; the backbone is bypassed here
            # apart from its embedding layer.
            x = self.transformer.embed(inputs)
            anomaly_scores = self.anomaly_detector(x)
            return anomaly_scores

        else:
            raise ValueError(f"Unknown task: {task}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model against a fresh (empty) L2 FAISS index
    # sized to the model dimension.
    config = ModelArgs()
    index = faiss.IndexFlatL2(config.dim)
    model = CombinedMultiModalTransformer(config, index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|