# Spaces: Sleeping
import functools
import mimetypes
import os
from io import BytesIO
from pathlib import Path

import google.generativeai as genai
import gradio as gr
import requests
import torch
from dotenv import load_dotenv
from PIL import Image, ImageDraw
from transformers import DetrImageProcessor, DetrForObjectDetection
# Pull variables from a local .env file into the process environment, then
# read the Google API key from it.
load_dotenv()
API_KEY = os.environ.get("GOOGLE_API_KEY")

# Register the key with the google.generativeai client.
genai.configure(api_key=API_KEY)

# Generation parameters handed to the model below.
generation_config = dict(
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    max_output_tokens=4000,
)

# Every harm category listed here uses the same blocking threshold.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

# Module-level generative model used by generate_gemini_response().
model = genai.GenerativeModel(
    model_name="gemini-1.5-flash-latest",
    generation_config=generation_config,
    safety_settings=safety_settings,
)
# Prompt sent with every cropped vehicle image. The model is asked to fill in
# three fields and to answer "introuvable" for any field it cannot determine.
input_prompt_template = (
    'give me the info of the car/truck (if an info is not available juste write "introuvable"):\n'
    "- plate:\n"
    "- model:\n"
    "- color: "
)
def input_image_setup(file_loc):
    """Read an image file and wrap its bytes in a ``mime_type``/``data`` part.

    Args:
        file_loc: Location of the image file (string or ``Path``).

    Returns:
        A one-element list containing a dict with two keys: ``"mime_type"``
        (guessed from the file extension, ``"image/jpeg"`` when the extension
        is unknown or not an image type) and ``"data"`` (the raw file bytes).

    Raises:
        FileNotFoundError: If ``file_loc`` does not exist.
    """
    img = Path(file_loc)
    if not img.exists():
        raise FileNotFoundError(f"Could not find image: {img}")

    # Label the payload with the file's real type instead of always claiming
    # JPEG; keep JPEG as the fallback so extension-less files behave as before.
    guessed_type, _ = mimetypes.guess_type(img.name)
    if guessed_type is None or not guessed_type.startswith("image/"):
        guessed_type = "image/jpeg"

    return [{"mime_type": guessed_type, "data": img.read_bytes()}]
def generate_gemini_response(input_prompt, image):
    """Send a prompt plus one image to the module-level model and return its text.

    Args:
        input_prompt: Text placed before the image in the request.
        image: Image payload; it is sent labelled as ``image/jpeg``.

    Returns:
        The ``text`` attribute of the response from ``model.generate_content``.
    """
    jpeg_part = {"mime_type": "image/jpeg", "data": image}
    reply = model.generate_content([input_prompt, jpeg_part])
    return reply.text
# Object detection part
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
_DETR_REVISION = "no_timm"
_VEHICLE_LABELS = frozenset({"car", "truck"})


@functools.lru_cache(maxsize=1)
def _load_detr():
    """Load the DETR processor and model once and reuse them on later calls."""
    processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT, revision=_DETR_REVISION)
    detector = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT, revision=_DETR_REVISION)
    return processor, detector


def detect_objects(image):
    """Detect cars and trucks in a PIL image, crop each one and outline it.

    Args:
        image: PIL image to analyse. Red boxes and labels are drawn directly
            onto this object.

    Returns:
        A tuple ``(image, detected_cars)``: the same image object with the
        annotations drawn on it, and a list of ``(jpeg_bytes, box)`` pairs,
        one per detected car/truck, where ``box`` is the detection box as a
        list of four coordinates rounded to two decimals.
    """
    # Loading the checkpoint is expensive, so it is cached across calls.
    processor, detector = _load_detr()

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = detector(**inputs)

    # image.size is (width, height); the post-processor is given (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    # threshold=0.9 already discards detections scoring 0.9 or less.
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Keep only detections whose label name is "car" or "truck".
    vehicles = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        if detector.config.id2label[label.item()] not in _VEHICLE_LABELS:
            continue
        if score.item() <= 0.9:
            continue
        vehicles.append((round(score.item(), 2), [round(coord, 2) for coord in box.tolist()]))

    # Crop every vehicle BEFORE drawing anything, so that no crop contains the
    # red annotations of another (overlapping) detection.
    detected_cars = [(image_to_bytes(image.crop(box)), box) for _, box in vehicles]

    draw = ImageDraw.Draw(image)
    for score, box in vehicles:
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0], box[1]), f"véhicule: {score}", fill="red")

    return image, detected_cars
# Image modes that Pillow's JPEG writer accepts as-is.
_JPEG_SAVE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def image_to_bytes(img):
    """Encode a PIL image as JPEG and return the encoded bytes.

    Args:
        img: PIL image to encode. Images in a mode JPEG cannot store
            (for example ``RGBA`` or ``P``) are converted to ``RGB`` first;
            the caller's image object is not modified.

    Returns:
        The JPEG file content as ``bytes``.
    """
    # Saving e.g. an RGBA crop straight to JPEG raises, so normalise the mode.
    if img.mode not in _JPEG_SAVE_MODES:
        img = img.convert("RGB")

    buffer = BytesIO()
    img.save(buffer, format="JPEG")
    return buffer.getvalue()
def upload_file(files):
    """Return the path of the first uploaded file.

    Args:
        files: Sequence of uploaded files. Each entry may be an object that
            exposes its path as ``.name`` or a plain path string.

    Returns:
        The path of the first entry. When ``files`` is empty or ``None`` the
        tuple ``(None, "Image not uploaded")`` is returned instead.

    NOTE(review): the empty case returns a 2-tuple while the normal case
    returns a single value; kept as-is for compatibility, but callers wiring
    this to a single output should confirm that is what they want.
    """
    if not files:
        return None, "Image not uploaded"

    first = files[0]
    # Only the first entry is needed; fall back to the entry itself when it is
    # already a path string without a ``name`` attribute.
    return getattr(first, "name", first)
def process_generate(files):
    """Detect vehicles in the first uploaded image and describe each one.

    Args:
        files: Sequence of uploaded files; only the first entry is used. An
            entry may expose its path as ``.name`` or be a plain path string.

    Returns:
        A tuple ``(image, text)``. ``image`` is the annotated image returned
        by ``detect_objects`` and ``text`` joins one description per detected
        vehicle. When ``files`` is empty or ``None`` the result is
        ``(None, "Image not uploaded")``.
    """
    if not files:
        return None, "Image not uploaded"

    first = files[0]
    file_path = getattr(first, "name", first)

    # Close the file handle promptly and normalise to RGB: uploads such as PNG
    # may carry an alpha channel or palette, which the JPEG crops cannot store.
    # convert() returns a fully loaded copy, so it stays usable after closing.
    with Image.open(file_path) as opened:
        image = opened.convert("RGB")

    # Detect cars and get the annotated image plus the cropped vehicles.
    detected_image, detected_cars = detect_objects(image)

    # One generated description per cropped vehicle.
    car_info_list = []
    for car_bytes, box in detected_cars:
        car_info = generate_gemini_response(input_prompt_template, car_bytes)
        car_info_list.append(f"véhicule aux coordonnées {box}:\n{car_info}\n")

    return detected_image, "\n".join(car_info_list)
# Gradio interface. Components are created in display order, so the statement
# order inside the ``with`` block is significant.
with gr.Blocks() as demo:
    header = gr.Label("RADARPICK: Vous avez pris en flag!")
    image_output = gr.Image()
    upload_button = gr.UploadButton(
        "Click to upload an image",
        file_types=["image"],
        file_count="multiple",
    )
    generate_button = gr.Button("Generate")
    file_output = gr.Textbox(label="Generated Content")

    # After an upload, preview the first uploaded file in the image component.
    upload_button.upload(
        fn=lambda files: files[0].name if files else None,
        inputs=[upload_button],
        outputs=image_output,
    )
    # "Generate" runs detection + description and fills both outputs.
    generate_button.click(
        fn=process_generate,
        inputs=[upload_button],
        outputs=[image_output, file_output],
    )

demo.launch()