# Streamlit demo: fine-tuned BLIP image captioning with beam search and
# toxicity filtering (the "Spaces: Running" lines were Hugging Face Spaces
# page chrome captured by the paste, not part of the program).
# Third-party dependencies, grouped per PEP 8 (alphabetical within the group).
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
)
@st.cache_resource
def load_caption_model():
    """Load the fine-tuned BLIP captioning model and its processor.

    Cached with ``st.cache_resource`` so the weights are read from disk only
    once per server process — without it, Streamlit re-runs this loader on
    every widget interaction (each script rerun).

    Returns:
        tuple: ``(model, processor, device)`` where the model is in eval mode
        on Apple-Silicon MPS when available, otherwise CPU.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # "saved_model_phase2" is a local directory with the fine-tuned weights
    # (presumably produced by an earlier training phase — confirm the path).
    model = BlipForConditionalGeneration.from_pretrained("saved_model_phase2")
    processor = BlipProcessor.from_pretrained("saved_model_phase2")
    model.to(device)
    model.eval()
    return model, processor, device
@st.cache_resource
def load_toxicity_model():
    """Load the ``unitary/toxic-bert`` toxicity classifier and tokenizer.

    Cached with ``st.cache_resource`` so the Hugging Face download/parse runs
    only once per server process instead of on every Streamlit rerun.

    Returns:
        tuple: ``(model, tokenizer, device)`` where the model is in eval mode
        on Apple-Silicon MPS when available, otherwise CPU.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("unitary/toxic-bert")
    model = AutoModelForSequenceClassification.from_pretrained("unitary/toxic-bert")
    model.to(device)
    model.eval()
    return model, tokenizer, device
# Load both models once at app start-up; these globals are shared by the
# upload/caption UI below.
caption_model, caption_processor, device = load_caption_model()
tox_model, tox_tokenizer, tox_device = load_toxicity_model()

st.title("🖼️ Advanced Image Captioning Demo")
st.write("Fine-tuned BLIP with Beam Search + Toxicity Filtering")

# Generation hyper-parameters, exposed as sidebar sliders.
with st.sidebar:
    st.header("⚙️ Generation Settings")
    num_beams = st.slider("Beam Size", 1, 10, 5)
    max_length = st.slider("Max Length", 10, 50, 20)
    length_penalty = st.slider("Length Penalty", 0.5, 2.0, 1.0, step=0.1)

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", width="stretch")

    if st.button("Generate Caption"):
        # --- Caption generation (beam search with the sidebar settings) ---
        with st.spinner("Generating caption..."):
            inputs = caption_processor(
                images=image,
                return_tensors="pt",
            ).to(device)
            with torch.no_grad():
                output_ids = caption_model.generate(
                    **inputs,
                    num_beams=num_beams,
                    max_length=max_length,
                    length_penalty=length_penalty,
                )
            caption = caption_processor.decode(
                output_ids[0],
                skip_special_tokens=True,
            )

            # --- Confidence score ---
            # exp(-loss) of the generated caption under the model, i.e. the
            # inverse per-token perplexity, mapped into (0, 1].
            with torch.no_grad():
                loss_inputs = caption_processor(
                    images=image,
                    text=caption,
                    return_tensors="pt",
                ).to(device)
                outputs = caption_model(
                    pixel_values=loss_inputs["pixel_values"],
                    input_ids=loss_inputs["input_ids"],
                    attention_mask=loss_inputs["attention_mask"],
                    labels=loss_inputs["input_ids"],
                )
            loss = outputs.loss
            confidence = torch.exp(-loss).item() if loss is not None else 0.0

            # --- Toxicity check ---
            # unitary/toxic-bert is a MULTI-LABEL classifier (toxic,
            # severe_toxic, obscene, threat, insult, identity_hate): each
            # logit gets an independent sigmoid. The previous code applied
            # softmax across the six labels and read index 1 (severe_toxic),
            # which both mis-normalizes the scores and reads the wrong label.
            tox_inputs = tox_tokenizer(
                caption,
                return_tensors="pt",
                truncation=True,
            ).to(tox_device)
            with torch.no_grad():
                tox_outputs = tox_model(**tox_inputs)
            probs = torch.sigmoid(tox_outputs.logits)
            # Overall toxicity = max over the six labels, so any single
            # category (e.g. threat) can trigger the block on its own.
            toxic_score = probs[0].max().item()

        # --- Display caption (blocked when the filter fires) ---
        if toxic_score > 0.6:
            st.error("⚠️ Generated caption flagged as potentially toxic.")
            st.markdown("### 🚫 Caption Blocked")
        else:
            st.success("Caption Generated")
            st.markdown(f"### 📝 {caption}")
            st.caption(f"Toxicity Score: {toxic_score:.2f}")
            st.caption(f"Confidence Score: {confidence:.2f}")

        # --- Vision attention heatmap ---
        # Mean-over-heads attention from the CLS token to every image patch
        # in the last ViT layer.
        with torch.no_grad():
            vision_outputs = caption_model.vision_model(
                inputs["pixel_values"],
                output_attentions=True,
                return_dict=True,
            )
        attentions = vision_outputs.attentions[-1]  # (batch, heads, seq, seq)
        attn = attentions[0].mean(0)                # average heads -> (seq, seq)
        cls_attn = attn[0, 1:]                      # CLS row, patch tokens only
        attn_map = cls_attn.cpu().numpy()
        # Guard the normalization against an all-zero map (division by zero).
        peak = attn_map.max()
        if peak > 0:
            attn_map = attn_map / peak
        # Assumes a square patch grid (true for BLIP's ViT, e.g. 24x24);
        # slice to side*side so reshape cannot fail on a stray extra token.
        side = int(np.sqrt(len(attn_map)))
        fig, ax = plt.subplots()
        ax.imshow(attn_map[: side * side].reshape(side, side), cmap="viridis")
        ax.set_title("Vision Attention Heatmap")
        ax.axis("off")
        st.pyplot(fig)