import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
import gradio as gr
from huggingface_hub import hf_hub_download
import zipfile
import os

# Global variables for models
object_detection_model = None
captioning_model = None
tokenizer = None
captioning_processor = None

# Load models during initialization
def init():
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: Load the YOLOv5 model from Hugging Face
    try:
        print("Loading YOLOv5 model...")
        # Get the Hugging Face auth token from an environment variable
        auth_token = os.getenv("HF_AUTH_TOKEN")
        if not auth_token:
            print("Error: HF_AUTH_TOKEN environment variable not set.")
            object_detection_model = None
        else:
            # Download the zip file from Hugging Face
            zip_path = hf_hub_download(repo_id='Mexbow/Yolov5_object_detection', filename='yolov5.zip', token=auth_token)
            # Extract the YOLOv5 model
            extract_path = './yolov5_model'  # Extraction path
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                os.makedirs(extract_path, exist_ok=True)
                zip_ref.extractall(extract_path)
            # Load the YOLOv5 model
            model_path = os.path.join(extract_path, 'yolov5/weights/best14.pt')
            if not os.path.exists(model_path):
                print(f"Error: YOLOv5 model file not found at {model_path}")
                object_detection_model = None
            else:
                object_detection_model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path, trust_repo=True)
                print("YOLOv5 model loaded successfully.")
    except Exception as e:
        print(f"Error loading YOLOv5 model: {e}")
        object_detection_model = None

    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
    try:
        print("Loading ViT-GPT2 model...")
        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")
        captioning_model, tokenizer, captioning_processor = None, None, None
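    # Note (deployment assumption, not from the original app): from_pretrained
    # loads the captioning model on CPU; on GPU hardware both the model and the
    # processed inputs would need an explicit .to("cuda").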

# Utility function to crop objects from the image based on bounding boxes
def crop_objects(image, boxes):
    cropped_images = []
    for box in boxes:
        left, top, right, bottom = box
        cropped_image = image.crop((left, top, right, bottom))
        cropped_images.append(cropped_image)
    return cropped_images
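
# Usage sketch (hypothetical values): for a PIL image img,
#     crop_objects(img, [(10, 20, 110, 220)])
# returns a one-element list holding the 100x200-pixel region of img.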

# Gradio interface function
def process_image(image):
    global object_detection_model, captioning_model, tokenizer, captioning_processor
    # Ensure models are loaded
    if object_detection_model is None or captioning_model is None or tokenizer is None or captioning_processor is None:
        # The Textbox outputs expect strings, so return a plain error message
        return None, "Error: models are not loaded properly.", None
    try:
        # Step 1: Perform object detection with YOLOv5
        results = object_detection_model(image)
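        # Each row of results.xyxy[0] is (x1, y1, x2, y2, confidence, class),
        # so columns 0-3 give the box, column 4 the score, and column 5 the class id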
        boxes = results.xyxy[0][:, :4].cpu().numpy()  # Bounding boxes
        labels = [results.names[int(class_id)] for class_id in results.xyxy[0][:, 5].cpu().numpy().astype(int)]  # Class names
        scores = results.xyxy[0][:, 4].cpu().numpy()  # Confidence scores

        # Step 2: Generate a caption for the whole image
        original_inputs = captioning_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            original_caption_ids = captioning_model.generate(**original_inputs)
        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)
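        # generate() runs with the checkpoint's default decoding settings; if
        # captions come out truncated or generic, standard generation kwargs
        # could be passed instead (an untested suggestion), e.g.
        #     captioning_model.generate(**original_inputs, num_beams=4, max_length=32)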

        # Step 3: Crop detected objects and generate a caption for each one
        cropped_images = crop_objects(image, boxes)
        captions = []
        for cropped_image in cropped_images:
            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
            with torch.no_grad():
                caption_ids = captioning_model.generate(**inputs)
            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            captions.append(caption)
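        # Performance note (assumption, behavior left unchanged here): the
        # processor also accepts a list of images, so the crops could be
        # captioned in a single batched generate() call rather than one per crop.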

        # Prepare the per-object results for display as a formatted string
        detection_results = ""
        for i, (label, box, score, caption) in enumerate(zip(labels, boxes, scores, captions)):
            left, top, right, bottom = box
            detection_results += (
                f"Object {i + 1}: {label} (confidence {score:.2f}), "
                f"box ({left:.0f}, {top:.0f}, {right:.0f}, {bottom:.0f}) - Caption: {caption}\n"
            )

        # Render the image with bounding boxes drawn on it
        result_image = results.render()[0]

        # Return the annotated image, the formatted captions, and the whole-image caption
        return result_image, detection_results, original_caption
    except Exception as e:
        return None, f"Error: {e}", None

# Initialize models
init()

# Gradio Interface
interface = gr.Interface(
    fn=process_image,  # Function to run
    inputs=gr.Image(type="pil"),  # Input: image upload
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),  # Output 1: image with bounding boxes
        gr.Textbox(label="Object Captions & Bounding Boxes", lines=10),  # Output 2: formatted per-object captions
        gr.Textbox(label="Whole Image Caption")  # Output 3: caption for the whole image
    ],
    live=True
)
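
# Note: live=True reruns process_image on every input change; with two large
# models per call, the default submit-button behavior (live=False) may be the
# safer choice on shared hardware.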

# Launch the Gradio app
interface.launch()
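
# On Hugging Face Spaces, launch() needs no arguments; when running locally, a
# temporary public link can be requested with interface.launch(share=True).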