Spaces:
Build error
Build error
| import os | |
| import torch | |
| import torch.nn as nn | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModel | |
class CLIPElectraFusion(nn.Module):
    """Multimodal fusion model combining CLIP (image) and ELECTRA (text) encoders.

    CLIP image features (projection_dim, e.g. 512) and mean-pooled ELECTRA
    text features projected to ``fusion_text_dim`` are zero-padded into
    disjoint halves of a shared ``fusion_dim`` vector, stacked as two tokens,
    fused by a 2-layer Transformer encoder, and classified by a 3-layer MLP.

    Args:
        clip_model: Pretrained CLIP model (must expose ``get_image_features``
            and ``config.projection_dim``).
        electra_model: Pretrained ELECTRA-style encoder (must expose
            ``config.hidden_size`` and return ``last_hidden_state``).
        fusion_text_dim: Dimension the pooled text embedding is projected to.
        num_classes: Number of output classes.
        freeze_encoders: If True, both backbone encoders are frozen so only
            the fusion/classifier layers are trainable.
    """

    def __init__(self, clip_model, electra_model,
                 fusion_text_dim=256,
                 num_classes=2, freeze_encoders=True):
        super().__init__()
        self.clip = clip_model
        self.electra = electra_model
        if freeze_encoders:
            # Keep the pretrained backbones fixed; train fusion head only.
            for p in self.clip.parameters():
                p.requires_grad = False
            for p in self.electra.parameters():
                p.requires_grad = False
        # CLIP image feature dimension (512 for ViT-B/32).
        self.img_dim = clip_model.config.projection_dim
        # ELECTRA hidden size (e.g. 768).
        electra_hidden_dim = electra_model.config.hidden_size
        # Text-only projection: electra_hidden_dim -> fusion_text_dim (e.g. 768 -> 256).
        self.project_text = nn.Sequential(
            nn.Linear(electra_hidden_dim, fusion_text_dim),
            nn.GELU(),
            nn.LayerNorm(fusion_text_dim),
        )
        # Fusion dimension: image dim + projected text dim (512 + 256 = 768).
        # NOTE: must be divisible by nhead=8 below.
        self.fusion_dim = self.img_dim + fusion_text_dim
        # Learned positional embedding for the 2 tokens (image, text).
        self.pos_embedding = nn.Parameter(torch.randn(1, 2, self.fusion_dim))
        # 2-layer Transformer encoder for cross-modal fusion.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.fusion_dim,
            nhead=8,
            dim_feedforward=self.fusion_dim * 4,
            dropout=0.1,
            batch_first=True,
        )
        self.fusion_transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        # 3-layer MLP classifier head.
        self.classifier = nn.Sequential(
            nn.Linear(self.fusion_dim, self.fusion_dim // 2),       # 768 -> 384
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(self.fusion_dim // 2, self.fusion_dim // 4),  # 384 -> 192
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(self.fusion_dim // 4, num_classes),           # 192 -> num_classes
        )

    def forward(self, pixel_values, input_ids, attention_mask):
        """Run a fused forward pass.

        Args:
            pixel_values: CLIP-preprocessed image batch.
            input_ids: ELECTRA token ids, shape (B, T).
            attention_mask: matching attention mask, shape (B, T).

        Returns:
            tuple: (logits (B, num_classes), L2-normalized image features,
            projected text features).
        """
        # Image features from CLIP. In current transformers versions
        # ``get_image_features`` returns a plain tensor; the fallbacks keep
        # compatibility with variants returning a ModelOutput or tuple.
        img_output = self.clip.get_image_features(pixel_values)
        if isinstance(img_output, torch.Tensor):
            img_feats = img_output
        elif hasattr(img_output, 'pooler_output'):
            img_feats = img_output.pooler_output
        else:
            img_feats = img_output[0] if isinstance(img_output, (tuple, list)) else img_output
        # L2-normalize image features; epsilon guards against a zero norm.
        img_proj = img_feats / (img_feats.norm(dim=-1, keepdim=True) + 1e-10)
        # Text features from ELECTRA.
        txt_out = self.electra(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = txt_out.last_hidden_state
        # Attention-masked mean pooling over the sequence dimension.
        attn = attention_mask.unsqueeze(-1).float()
        sum_emb = (last_hidden * attn).sum(dim=1)
        sum_mask = attn.sum(dim=1).clamp(min=1e-9)  # avoid divide-by-zero on empty masks
        text_emb = sum_emb / sum_mask
        # Project pooled text embedding (hidden_size -> fusion_text_dim).
        text_proj = self.project_text(text_emb)
        # Place each modality in its own half of the fusion vector.
        img_token = torch.cat([img_proj, torch.zeros_like(text_proj)], dim=-1)
        text_token = torch.cat([torch.zeros_like(img_proj), text_proj], dim=-1)
        # Stack to (B, 2, fusion_dim) and add the positional embedding.
        tokens = torch.stack([img_token, text_token], dim=1)
        tokens = tokens + self.pos_embedding
        # Fuse with the 2-layer Transformer encoder.
        fused_tokens = self.fusion_transformer(tokens)
        # Classify from the first (image) token.
        fused_rep = fused_tokens[:, 0, :]
        logits = self.classifier(fused_rep)
        return logits, img_proj, text_proj
# Global variables for model and processors (populated by load_model()).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # prefer GPU when available
model = None               # CLIPElectraFusion instance, set by load_model()
clip_processor = None      # CLIPProcessor for image preprocessing
electra_tokenizer = None   # tokenizer matching the ELECTRA text encoder
# Label mapping: class index -> human-readable label
LABEL_NAMES = {0: 'NON-SELF-HARM', 1: 'SELF-HARM'}
def load_model():
    """Load the pretrained encoders, build the fusion model, and load trained weights.

    Populates the module-level ``model``, ``clip_processor`` and
    ``electra_tokenizer`` globals. Trained weights are fetched from the
    Hugging Face Hub first; on any failure a local checkpoint file with the
    same name is tried; if neither is available the untrained model is kept
    so the app can still start.
    """
    global model, clip_processor, electra_tokenizer
    print("Loading CLIP model...")
    clip_model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
    clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
    print("Loading ELECTRA model...")
    electra_model = AutoModel.from_pretrained('sentinet/suicidality')
    electra_tokenizer = AutoTokenizer.from_pretrained('sentinet/suicidality')
    print("Initializing fusion model...")
    model = CLIPElectraFusion(
        clip_model=clip_model,
        electra_model=electra_model,
        fusion_text_dim=256,
        num_classes=2,
        freeze_encoders=True
    )
    # Load trained weights from Hugging Face Model Hub.
    # Change this to your model repo: "username/model-name"
    model_repo = "elsaelisa09/meme-self-harm-detection-model"
    checkpoint_filename = "bestmodel_ArsitekturA_Bd.pth"
    try:
        # Imported lazily so the app can still run without huggingface_hub
        # (the local-file fallback below is used instead).
        from huggingface_hub import hf_hub_download
        print(f"Downloading model from {model_repo}...")
        checkpoint_path = hf_hub_download(
            repo_id=model_repo,
            filename=checkpoint_filename,
            repo_type="model"
        )
        print(f"Loading checkpoint from {checkpoint_path}...")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted repos (consider weights_only=True if the
        # checkpoint is a plain state_dict).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint)
        print("Model loaded successfully from Hugging Face Hub!")
    except Exception as e:
        # Deliberate best-effort fallback: report the Hub error, then try
        # a local checkpoint before giving up.
        print(f"Error loading model from Hub: {e}")
        print("Trying to load from local file...")
        checkpoint_path = checkpoint_filename
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint)
            print("Model loaded from local file!")
        else:
            # Fixed: was an f-string with no placeholders.
            print("Warning: Model file not found. Using untrained model.")
    model.to(device)
    model.eval()  # inference-only app; disable dropout
    print(f"Model loaded on {device}")
def predict(image, text):
    """Classify an (image, OCR text) pair as self-harm or non-self-harm.

    Args:
        image: PIL Image uploaded by the user.
        text: str - Text extracted from the image (OCR text).

    Returns:
        tuple: (markdown result string, {label: formatted probability}).
    """
    # Guard clauses: model must be loaded and an image supplied.
    if model is None:
        return "Model not loaded", {}
    if image is None:
        return "Please upload an image", {}
    # Normalize missing/whitespace-only text to an empty string.
    if not text or text.strip() == "":
        text = ""
    # Image preprocessing via the CLIP processor.
    pixel_values = clip_processor(images=image, return_tensors='pt')['pixel_values'].to(device)
    # Text preprocessing via the ELECTRA tokenizer.
    encoded = electra_tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=128,
        return_tensors='pt'
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    # Inference (no gradients needed).
    with torch.no_grad():
        logits, _, _ = model(pixel_values, input_ids, attention_mask)
        probs = torch.softmax(logits, dim=-1)
        top_class = torch.argmax(probs, dim=-1).item()
        top_prob = probs[0, top_class].item()
    # Format the outputs for the Gradio components.
    label = LABEL_NAMES[top_class]
    scores = {
        LABEL_NAMES[0]: f"{probs[0, 0].item():.2%}",
        LABEL_NAMES[1]: f"{probs[0, 1].item():.2%}"
    }
    return f"**Predicted Label:** {label}\n\n**Confidence:** {top_prob:.2%}", scores
# Initialize model on startup (downloads encoders/weights; may take a while).
print("Initializing model...")
load_model()
# Create Gradio interface: image + OCR-text inputs on the left,
# prediction markdown + confidence label on the right.
with gr.Blocks(title="Self-Harm Detection - Multimodal Model") as demo:
    # Header / usage instructions.
    gr.Markdown(
        """
        # 🔍 Self-Harm Content Detection
        ### Multimodal Image + Text Classification Model
        This model analyzes both **image** and **text** content to detect potential self-harm content.
        **How to use:**
        1. Upload an image
        2. Enter the text visible in the image (OCR text)
        3. Click "Predict" to see the results
        """
    )
    with gr.Row():
        with gr.Column():
            # Inputs: image upload + OCR text box + predict button.
            image_input = gr.Image(
                label="Upload Image",
                type="pil",  # predict() expects a PIL Image
                height=300
            )
            text_input = gr.Textbox(
                label="Text from Image (OCR)",
                placeholder="Enter the text visible in the image...",
                lines=3
            )
            predict_btn = gr.Button("🔍 Predict", variant="primary", size="lg")
        with gr.Column():
            # Outputs: markdown verdict + per-class confidence scores.
            output_text = gr.Markdown(label="Prediction Result")
            output_confidence = gr.Label(label="Confidence Scores", num_top_classes=2)
    # Wire the button to the predict() function.
    predict_btn.click(
        fn=predict,
        inputs=[image_input, text_input],
        outputs=[output_text, output_confidence]
    )
    # Footer: model description and disclaimer.
    gr.Markdown(
        """
        ---
        ### ℹ️ About the Model
        **Architecture:** CLIP (Image Encoder) + ELECTRA (Text Encoder) + Transformer Fusion
        **Classes:**
        - **NON-SELF-HARM**: Content that does not contain self-harm indicators
        - **SELF-HARM**: Content that may contain self-harm related material
        **Note:** This model is designed for research purposes. Always consult with mental health professionals
        for serious concerns.
        """
    )
# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()