"""Gradio demo: caption an uploaded image with BLIP, then classify the caption's emotion.

The caption comes from ``Salesforce/blip-image-captioning-large``; the emotion
label for that caption comes from ``SamLowe/roberta-base-go_emotions``.
"""

import gradio as gr
from transformers import (  # noqa: F401  (Roberta* kept for backward compatibility)
    BlipProcessor,
    BlipForConditionalGeneration,
    RobertaTokenizer,
    RobertaForSequenceClassification,
)
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load the image captioning model and its processor.
caption_model_name = "Salesforce/blip-image-captioning-large"
caption_processor = BlipProcessor.from_pretrained(caption_model_name)
caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_name)

# Load the emotion classification model and its tokenizer.
emotion_model_name = "SamLowe/roberta-base-go_emotions"
emotion_tokenizer = AutoTokenizer.from_pretrained(emotion_model_name)
emotion_model = AutoModelForSequenceClassification.from_pretrained(emotion_model_name)


def _generate_caption(image):
    """Return the BLIP caption for *image* as plain text."""
    caption_inputs = caption_processor(images=image, return_tensors="pt")
    caption_ids = caption_model.generate(**caption_inputs)
    return caption_processor.decode(caption_ids[0], skip_special_tokens=True)


def _classify_emotion(text):
    """Return the name of the highest-scoring emotion class for *text*."""
    emotion_inputs = emotion_tokenizer(
        text,
        max_length=128,
        truncation=True,
        return_tensors="pt",
    )
    logits = emotion_model(**emotion_inputs).logits
    label_id = int(logits[0].argmax())
    # The argmax is a *class index*, not a vocabulary token id, so it must be
    # mapped through the model config; tokenizer.decode would return an
    # unrelated word-piece.
    return emotion_model.config.id2label[label_id]


def generate_caption_and_analyze_emotions(image):
    """Caption *image* and report the dominant emotion of that caption.

    Args:
        image: The uploaded image as delivered by the Gradio ``Image`` input
            (a PIL image with ``type="pil"``), or ``None`` when the user
            submits without uploading anything.

    Returns:
        A human-readable string containing the emotion label and the caption,
        or a prompt to upload an image when *image* is ``None``.
    """
    if image is None:
        return "Please upload an image."

    # Pure inference: disable autograd to avoid building graphs / holding activations.
    with torch.no_grad():
        decoded_caption = _generate_caption(image)
        emotion_label = _classify_emotion(decoded_caption)

    return (
        f"The sentiment in the provided image shows: {emotion_label}.\n\n"
        f"Generated Caption: {decoded_caption}"
    )


# Define the Gradio interface (gr.inputs / gr.outputs were removed in Gradio 4).
inputs = gr.Image(type="pil", label="Upload an image")
outputs = gr.Textbox(label="Generated Caption and Sentiment Analysis")

# Create the Gradio app
app = gr.Interface(fn=generate_caption_and_analyze_emotions, inputs=inputs, outputs=outputs)

# Launch the Gradio app
if __name__ == "__main__":
    app.launch()