# Hugging Face Space: elsaelisa09 — commit f9b5720 ("Upload 6 files")
import os
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModel
class CLIPElectraFusion(nn.Module):
    """
    Multimodal fusion model combining CLIP (image) and ELECTRA (text) features.

    Each encoder's embedding is placed into a disjoint slice of a shared
    fusion vector (image half zero-padded on the text slice and vice versa),
    the two resulting tokens are fused by a 2-layer Transformer encoder, and
    the fused image token is classified by a 3-layer MLP.

    Args:
        clip_model: pretrained ``transformers`` CLIP model (image encoder).
        electra_model: pretrained ``transformers`` ELECTRA model (text encoder).
        fusion_text_dim: dimension the text embedding is projected down to.
        num_classes: number of output classes.
        freeze_encoders: if True, both encoders' parameters are frozen.
    """
    def __init__(self, clip_model, electra_model,
                 fusion_text_dim=256,
                 num_classes=2, freeze_encoders=True):
        super().__init__()
        self.clip = clip_model
        self.electra = electra_model
        # Freeze both pretrained encoders so only the fusion head trains.
        if freeze_encoders:
            for p in self.clip.parameters():
                p.requires_grad = False
            for p in self.electra.parameters():
                p.requires_grad = False
        # CLIP image-feature output dimension (512 for ViT-B/32).
        self.img_dim = clip_model.config.projection_dim  # 512
        # ELECTRA hidden-state dimension (e.g. 768).
        electra_hidden_dim = electra_model.config.hidden_size  # e.g. 768
        # Projection applied to text only: electra_hidden_dim -> fusion_text_dim (256).
        self.project_text = nn.Sequential(
            nn.Linear(electra_hidden_dim, fusion_text_dim),
            nn.GELU(),
            nn.LayerNorm(fusion_text_dim)
        )
        # Fusion dimension: 512 (CLIP) + 256 (projected text) = 768.
        self.fusion_dim = self.img_dim + fusion_text_dim
        # Learned positional embedding for the 2 tokens (image + text).
        self.pos_embedding = nn.Parameter(torch.randn(1, 2, self.fusion_dim))
        # 2-layer Transformer encoder used for fusion.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.fusion_dim,
            nhead=8,
            dim_feedforward=self.fusion_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.fusion_transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        # 3-layer MLP classifier head.
        self.classifier = nn.Sequential(
            nn.Linear(self.fusion_dim, self.fusion_dim // 2),   # 768 -> 384
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(self.fusion_dim // 2, self.fusion_dim // 4),  # 384 -> 192
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(self.fusion_dim // 4, num_classes)        # 192 -> num_classes
        )

    def forward(self, pixel_values, input_ids, attention_mask):
        """
        Run a forward pass over a batch of image/text pairs.

        Args:
            pixel_values: CLIP-preprocessed image batch
                (assumes (batch, 3, H, W) from CLIPProcessor — TODO confirm).
            input_ids: ELECTRA token ids, shape (batch, seq_len).
            attention_mask: 1/0 mask matching ``input_ids``.

        Returns:
            tuple: (logits (batch, num_classes),
                    L2-normalized image features,
                    projected text features)
        """
        # Extract image features from CLIP. In recent transformers versions
        # get_image_features returns a plain tensor; the fallbacks below
        # guard against other return types — presumably for version drift.
        img_output = self.clip.get_image_features(pixel_values)
        if hasattr(img_output, 'pooler_output'):
            img_feats = img_output.pooler_output
        elif isinstance(img_output, torch.Tensor):
            img_feats = img_output
        else:
            img_feats = img_output[0] if isinstance(img_output, (tuple, list)) else img_output
        # L2-normalize image features (epsilon avoids division by zero).
        img_proj = img_feats / (img_feats.norm(dim=-1, keepdim=True) + 1e-10)
        # Extract text features from ELECTRA.
        txt_out = self.electra(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = txt_out.last_hidden_state
        # Mask-aware mean pooling over the sequence dimension.
        attn = attention_mask.unsqueeze(-1).float()
        sum_emb = (last_hidden * attn).sum(dim=1)
        sum_mask = attn.sum(dim=1).clamp(min=1e-9)  # avoid /0 on empty masks
        text_emb = sum_emb / sum_mask
        # Project text from hidden_size (e.g. 768) down to 256 dims.
        text_proj = self.project_text(text_emb)
        # Build the two fusion tokens: each modality occupies its own slice
        # of the fusion vector, with the other slice zeroed.
        img_token = torch.cat([img_proj, torch.zeros_like(text_proj)], dim=-1)
        text_token = torch.cat([torch.zeros_like(img_proj), text_proj], dim=-1)
        # Stack tokens and add the learned positional embedding.
        tokens = torch.stack([img_token, text_token], dim=1)
        tokens = tokens + self.pos_embedding
        # Fuse with the 2-layer Transformer.
        fused_tokens = self.fusion_transformer(tokens)
        # Use the first (image) token as the fused representation.
        fused_rep = fused_tokens[:, 0, :]
        # Classify with the 3-layer MLP head.
        logits = self.classifier(fused_rep)
        return logits, img_proj, text_proj
# Module-level state shared by load_model() and predict().
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = None              # fusion model, populated by load_model()
clip_processor = None     # CLIP image processor, populated by load_model()
electra_tokenizer = None  # ELECTRA tokenizer, populated by load_model()

# Class index -> human-readable label.
LABEL_NAMES = {0: 'NON-SELF-HARM', 1: 'SELF-HARM'}
def load_model():
    """Load the encoders, build the fusion model, and restore trained weights.

    Populates the module-level globals ``model``, ``clip_processor`` and
    ``electra_tokenizer``. Trained weights are fetched from the Hugging Face
    Hub; on failure a local checkpoint file is tried; if neither is available
    the model keeps its random initialization (a warning is printed).
    """
    global model, clip_processor, electra_tokenizer

    print("Loading CLIP model...")
    clip_model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
    clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

    print("Loading ELECTRA model...")
    electra_model = AutoModel.from_pretrained('sentinet/suicidality')
    electra_tokenizer = AutoTokenizer.from_pretrained('sentinet/suicidality')

    print("Initializing fusion model...")
    model = CLIPElectraFusion(
        clip_model=clip_model,
        electra_model=electra_model,
        fusion_text_dim=256,
        num_classes=2,
        freeze_encoders=True
    )

    # Trained weights live in this Hugging Face model repo.
    model_repo = "elsaelisa09/meme-self-harm-detection-model"
    checkpoint_filename = "bestmodel_ArsitekturA_Bd.pth"

    def _restore(path):
        # Shared restore path for Hub and local checkpoints.
        # NOTE(review): torch.load unpickles arbitrary objects; consider
        # weights_only=True once the checkpoint format is confirmed.
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint)

    try:
        from huggingface_hub import hf_hub_download
        print(f"Downloading model from {model_repo}...")
        checkpoint_path = hf_hub_download(
            repo_id=model_repo,
            filename=checkpoint_filename,
            repo_type="model"
        )
        print(f"Loading checkpoint from {checkpoint_path}...")
        _restore(checkpoint_path)
        print("Model loaded successfully from Hugging Face Hub!")
    except Exception as e:
        # Boundary handler: fall back to a local file rather than crash the app.
        print(f"Error loading model from Hub: {e}")
        print("Trying to load from local file...")
        if os.path.exists(checkpoint_filename):
            _restore(checkpoint_filename)
            print("Model loaded from local file!")
        else:
            print("Warning: Model file not found. Using untrained model.")

    model.to(device)
    model.eval()
    print(f"Model loaded on {device}")
def predict(image, text):
    """
    Classify an (image, OCR text) pair as SELF-HARM vs NON-SELF-HARM.

    Args:
        image: PIL Image to classify (None is rejected with a message).
        text: str - Text extracted from the image (OCR text); may be empty/None.

    Returns:
        tuple: (markdown result string,
                {label: probability-float} dict suitable for gr.Label)
    """
    if model is None:
        return "Model not loaded", {}
    if image is None:
        return "Please upload an image", {}

    # Normalize missing/whitespace-only text to the empty string.
    if not text or text.strip() == "":
        text = ""

    # Preprocess image for CLIP.
    clip_inputs = clip_processor(images=image, return_tensors='pt')
    pixel_values = clip_inputs['pixel_values'].to(device)

    # Tokenize text for ELECTRA (fixed-length padding to 128 tokens).
    enc = electra_tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=128,
        return_tensors='pt'
    )
    input_ids = enc['input_ids'].to(device)
    attention_mask = enc['attention_mask'].to(device)

    # Inference without gradient tracking.
    with torch.no_grad():
        logits, _, _ = model(pixel_values, input_ids, attention_mask)
        probabilities = torch.softmax(logits, dim=-1)
        predicted_class = torch.argmax(probabilities, dim=-1).item()
        confidence = probabilities[0, predicted_class].item()

    predicted_label = LABEL_NAMES[predicted_class]
    # gr.Label expects {label: float in [0, 1]} — raw probabilities, not
    # pre-formatted percentage strings (strings break the confidence bars).
    confidence_dict = {
        LABEL_NAMES[0]: probabilities[0, 0].item(),
        LABEL_NAMES[1]: probabilities[0, 1].item()
    }
    result = f"**Predicted Label:** {predicted_label}\n\n**Confidence:** {confidence:.2%}"
    return result, confidence_dict
# Initialize model on startup — runs at import time (e.g. when the Space boots),
# so the first request does not pay the download/initialization cost.
print("Initializing model...")
load_model()
# Create Gradio interface: image + OCR-text inputs on the left,
# markdown verdict and confidence bars on the right.
with gr.Blocks(title="Self-Harm Detection - Multimodal Model") as demo:
    # App header and usage instructions.
    gr.Markdown(
        """
        # 🔍 Self-Harm Content Detection
        ### Multimodal Image + Text Classification Model
        This model analyzes both **image** and **text** content to detect potential self-harm content.
        **How to use:**
        1. Upload an image
        2. Enter the text visible in the image (OCR text)
        3. Click "Predict" to see the results
        """
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            image_input = gr.Image(
                label="Upload Image",
                type="pil",  # predict() expects a PIL Image
                height=300
            )
            text_input = gr.Textbox(
                label="Text from Image (OCR)",
                placeholder="Enter the text visible in the image...",
                lines=3
            )
            predict_btn = gr.Button("🔍 Predict", variant="primary", size="lg")
        # Right column: outputs.
        with gr.Column():
            output_text = gr.Markdown(label="Prediction Result")
            output_confidence = gr.Label(label="Confidence Scores", num_top_classes=2)
    # Wire the button to the predict() function.
    predict_btn.click(
        fn=predict,
        inputs=[image_input, text_input],
        outputs=[output_text, output_confidence]
    )
    # Footer: model description and disclaimer.
    gr.Markdown(
        """
        ---
        ### ℹ️ About the Model
        **Architecture:** CLIP (Image Encoder) + ELECTRA (Text Encoder) + Transformer Fusion
        **Classes:**
        - **NON-SELF-HARM**: Content that does not contain self-harm indicators
        - **SELF-HARM**: Content that may contain self-harm related material
        **Note:** This model is designed for research purposes. Always consult with mental health professionals
        for serious concerns.
        """
    )
# Launch the app when run as a script (HF Spaces also executes this module directly).
if __name__ == "__main__":
    demo.launch()