Spaces:
Build error
```python
import gradio as gr
from transformers import GitProcessor, GitForCausalLM, BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import torch
from sentence_transformers import SentenceTransformer, util
import nltk
from nltk.corpus import stopwords

# Ensure nltk stopwords are downloaded
nltk.download('stopwords')

# Load the stop words from nltk
stop_words = set(stopwords.words('english'))

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the processor and model for GIT
processor_git = GitProcessor.from_pretrained("microsoft/git-base-coco")
model_git = GitForCausalLM.from_pretrained("microsoft/git-base-coco").to(device)

# Load the processor and model for BLIP
processor_blip = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model_blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

# Load the SentenceTransformer model
model_sentence = SentenceTransformer('sentence-transformers/stsb-roberta-base')

# Define categories and associated prompts
category_prompts = {
    "nature": "Describe the beautiful aspects of the scene in nature. What might be happening outside of the picture?",
    "technology": "Explain how technology is being used here and what future implications it might have.",
    "kids": "Talk about the joy of childhood visible in this image, and imagine what the kids might do next."
}

# Function to generate a prompt based on the category
def generate_prompt(image_category):
    if image_category in category_prompts:
        return category_prompts[image_category]
    else:
        return "Describe the image in detail and predict what could be happening beyond it."

# Function to preprocess the image
def preprocess_image(image):
    inputs = processor_git(images=image, return_tensors="pt")
    return inputs['pixel_values'].to(device), image  # Return both pixel values and image for display

# Function to generate a caption with the GIT model
def generate_caption_git(image, category):
    # Preprocess image
    pixel_values, processed_image = preprocess_image(image)

    # Generate prompt for the category
    prompt = generate_prompt(category)

    # Tokenize the prompt together with the image
    inputs = processor_git(text=prompt, images=image, return_tensors="pt")
    inputs['pixel_values'] = pixel_values  # Use preprocessed pixel values

    # Move inputs to the selected device
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Generate caption
    generated_ids = model_git.generate(
        pixel_values=inputs['pixel_values'],
        max_length=300,
        num_beams=5,
        repetition_penalty=2.5
    )
    generated_text = processor_git.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text, processed_image

# Function to generate the final caption using the BLIP model
def generate_caption_blip(image, git_caption):
    # Preprocess image for BLIP, conditioning on the GIT caption
    inputs = processor_blip(images=image, text=git_caption, return_tensors="pt").to(device)

    # Generate final caption with BLIP
    generated_ids = model_blip.generate(
        inputs['pixel_values'],
        max_length=300,          # Allow longer captions
        num_beams=5,
        repetition_penalty=2.5,
        length_penalty=1.5,      # Encourage longer outputs
        min_length=120,          # Minimum output length in tokens
        early_stopping=True
    )
    final_caption = processor_blip.decode(generated_ids[0], skip_special_tokens=True)
    return final_caption

# Function to compute the semantic similarity score
def compute_semantic_similarity(generated_caption, candidate_answer):
    # Encode the sentences to get their embeddings
    generated_embedding = model_sentence.encode(generated_caption, convert_to_tensor=True)
    candidate_embedding = model_sentence.encode(candidate_answer, convert_to_tensor=True)

    # Compute the cosine similarity
    similarity_score = util.pytorch_cos_sim(generated_embedding, candidate_embedding).item()

    # Scale the similarity score to a percentage (0-100)
    return similarity_score * 100

# Gradio function
def gradio_interface(image, category, candidate_answer):
    # Generate the initial caption using GIT
    git_caption, _ = generate_caption_git(image, category)

    # Generate the final caption using BLIP based on the GIT output
    final_caption = generate_caption_blip(image, git_caption)

    # Compute the similarity between the BLIP caption and the candidate answer
    similarity_score = compute_semantic_similarity(final_caption, candidate_answer)
    return final_caption, similarity_score

# Create the Gradio interface
image_input = gr.Image(type="pil")
category_input = gr.Dropdown(choices=["nature", "technology", "kids"], label="Select Category")
candidate_input = gr.Textbox(label="Enter your answer", placeholder="Type your answer here...")
outputs = [gr.Textbox(label="BLIP Caption"),
           gr.Textbox(label="Semantic Similarity Score")]

# Launch the Gradio interface
gr.Interface(
    fn=gradio_interface,
    inputs=[image_input, category_input, candidate_input],
    outputs=outputs,
    title="Image Captioning with BLIP and Semantic Similarity",
    description="Upload an image, select a category, and input your answer to compare with the BLIP-generated caption."
).launch()
```
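The scoring step can be checked on its own while debugging the Space locally, without loading the captioning models or the Gradio UI. The snippet below is a minimal sketch of that check; the two example sentences are made up for illustration, and only the sentence-transformers package is assumed to be installed.

```python
# Minimal sketch: standalone check of the semantic-similarity scoring used in the app above.
# The example caption and answer are placeholders, not outputs of the real models.
from sentence_transformers import SentenceTransformer, util

model_sentence = SentenceTransformer('sentence-transformers/stsb-roberta-base')

generated_caption = "Two children playing with a kite on a sunny beach."
candidate_answer = "Kids flying a kite near the ocean."

gen_emb = model_sentence.encode(generated_caption, convert_to_tensor=True)
cand_emb = model_sentence.encode(candidate_answer, convert_to_tensor=True)

# Cosine similarity scaled to 0-100, matching compute_semantic_similarity in the app
score = util.pytorch_cos_sim(gen_emb, cand_emb).item() * 100
print(f"Semantic similarity: {score:.1f}%")
```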