Spaces:
Running
Running
| from flask import Flask, render_template, request, jsonify | |
| from PIL import Image | |
| from io import BytesIO | |
| import torch | |
| from torchvision import models, transforms | |
| from torchvision.models.detection import fasterrcnn_resnet50_fpn | |
| import os | |
# WSGI application object used by the rest of this module.
app = Flask(import_name=__name__)
# Load ImageNet class index
def load_imagenet_class_index(class_index_path='imagenet_classes.txt'):
    """Read class names from a text file, one name per line.

    Args:
        class_index_path: Path of the class list. Defaults to
            ``imagenet_classes.txt`` in the current working directory.

    Returns:
        A list with one whitespace-stripped string per line of the file, in
        file order. Blank lines are kept as empty strings so that list
        positions keep matching line numbers.

    Raises:
        FileNotFoundError: If ``class_index_path`` does not exist.
    """
    try:
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(class_index_path, encoding='utf-8') as f:
            return [line.strip() for line in f]
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"ImageNet class index file not found at {class_index_path}"
        ) from err
# Class names looked up by classifier output index in real_image_analysis.
imagenet_classes = load_imagenet_class_index()

# Load pre-trained models
# Both networks are put in evaluation mode straight away; eval() returns the
# module itself, so construction and mode switch can be chained.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` - confirm against the pinned torchvision version.
resnet = models.resnet50(pretrained=True).eval()
fasterrcnn = fasterrcnn_resnet50_fpn(pretrained=True).eval()
# Image transformation
# Preprocessing applied before the classifier: resize, centre-crop to
# 224 pixels, convert to a tensor, then normalise each channel with the
# mean/std values given below.
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)
# COCO dataset class names
# The list position is the label id used to index it in real_image_analysis.
# 'N/A' entries are placeholders (presumably unused ids - confirm against the
# detector's label map). Entries are grouped ten ids per group.
COCO_INSTANCE_CATEGORY_NAMES = [
    # 0-9
    '__background__', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # 90
    'toothbrush',
]
# Function for real image analysis
def real_image_analysis(image):
    """Describe an image by likely objects, dominant colors and scene type.

    Runs the module-level ``resnet`` classifier and ``fasterrcnn`` detector
    on ``image`` and merges the labels they produce.

    Args:
        image: A PIL image. NOTE(review): the visible caller converts to RGB
            before calling; other modes are presumably unsupported - confirm.

    Returns:
        A dict with three keys:
        ``"objects"`` - at most five distinct labels: the classifier's three
        best classes first, then detector labels in the order the detector
        returned them;
        ``"colors"`` - whatever ``get_dominant_colors(image)`` returns;
        ``"scene"`` - ``"outdoor"`` if any label is one of
        sky/tree/grass/mountain, otherwise ``"indoor"``.
    """
    # Classification: only the three best classes are used, so request them
    # directly instead of sorting every logit.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = resnet(batch)
    top_indices = torch.topk(logits[0], k=3).indices.tolist()
    labels = [imagenet_classes[idx] for idx in top_indices]

    # Object detection using Faster R-CNN (takes the un-normalised tensor).
    detector_input = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        detections = fasterrcnn(detector_input)
    labels.extend(
        COCO_INSTANCE_CATEGORY_NAMES[label_id]
        for label_id in detections[0]['labels'].tolist()
    )

    # Remove duplicates while keeping first-seen order. A set would scramble
    # the order, making the five labels returned below arbitrary and
    # potentially different from run to run.
    objects = list(dict.fromkeys(labels))

    # Get dominant colors
    colors = get_dominant_colors(image)

    # Determine scene (indoor/outdoor)
    # NOTE(review): this only fires on exact label matches; confirm that the
    # class lists in use actually contain these names.
    outdoor_markers = {'sky', 'tree', 'grass', 'mountain'}
    scene = "outdoor" if outdoor_markers.intersection(objects) else "indoor"

    return {
        "objects": objects[:5],  # Limit to top 5 objects
        "colors": colors,
        "scene": scene,
    }
# Function to get dominant colors from an image
def get_dominant_colors(image, num_colors=3):
    """Name the most frequent colors of an image.

    The image is copied, shrunk to fit within 100x100 pixels, reduced to an
    adaptive palette of ``num_colors`` entries, and each palette entry that
    occurs is mapped to a name with ``rgb_to_name``.

    Args:
        image: A PIL image; it is not modified (a copy is shrunk).
        num_colors: Maximum number of colors to report.

    Returns:
        A list of color names ordered from most to least frequent. It holds
        fewer than ``num_colors`` entries when the image contains fewer
        distinct colors (for example a solid-color image).
    """
    # Resize image to speed up processing
    img = image.copy()
    img.thumbnail((100, 100))

    # Get colors from the image
    paletted = img.convert('P', palette=Image.ADAPTIVE, colors=num_colors)
    palette = paletted.getpalette()
    # getcolors() yields (pixel_count, palette_index) pairs; sorting in
    # reverse puts the most frequent palette entry first.
    color_counts = sorted(paletted.getcolors(), reverse=True)

    colors = []
    # Slice instead of indexing range(num_colors): there may be fewer
    # palette entries in use than requested, which used to raise IndexError.
    for _count, palette_index in color_counts[:num_colors]:
        start = palette_index * 3  # palette is a flat [r, g, b, r, g, b, ...]
        colors.append(rgb_to_name(palette[start:start + 3]))
    return colors
# Function to convert RGB to color name (simplified)
def rgb_to_name(rgb):
    """Return a coarse color name for an ``(r, g, b)`` triple.

    The name is that of the channel which is strictly larger than both
    others. When no single channel dominates (any tie for the largest value)
    the result is ``"gray"``.
    """
    red, green, blue = rgb
    channels = (("red", red), ("green", green), ("blue", blue))
    peak = max(red, green, blue)
    leaders = [name for name, value in channels if value == peak]
    # Exactly one leader means that channel beats both others strictly.
    if len(leaders) == 1:
        return leaders[0]
    return "gray"
# Function to simulate the generation of answers from metadata
def generate_answer_from_metadata(metadata, question, complexity):
    """Return a placeholder answer describing the inputs.

    Args:
        metadata: Mapping with ``'objects'`` and ``'colors'`` (iterables of
            strings) and ``'scene'``.
        question: The user's question.
        complexity: A string; its lower-cased form goes into the prompt.

    Returns:
        A string echoing ``metadata``, ``question`` and ``complexity``.
    """
    object_list = ', '.join(metadata['objects'])
    color_list = ', '.join(metadata['colors'])
    # The prompt is assembled but not sent anywhere: no client is defined in
    # this file. It is what a real model call would receive once one is wired in.
    prompt = f"""
    The image contains the following objects: {object_list}.
    The dominant colors are {color_list}.
    It appears to be an {metadata['scene']} scene.
    Based on this, provide a {complexity.lower()} response to the following question: {question}
    """
    # Replace this section with the actual client code if using an API
    return (
        f"Simulated answer based on metadata: {metadata}. "
        f"Question: {question}, Complexity: {complexity}."
    )
# Flask routes
# Without a route decorator the view is never registered with the app and
# every request would get a 404.
@app.route('/')
def index():
    """Render the ``index.html`` template."""
    return render_template('index.html')
# NOTE(review): the URL rule is an assumption - the view had no route
# decorator at all; confirm the path against the form/fetch call in
# index.html.
@app.route('/ask', methods=['POST'])
def ask_question():
    """Answer a question about an uploaded image.

    Reads the ``image`` file plus the ``question`` and optional
    ``complexity`` form fields (default ``'Default'``) from the request.

    Returns:
        A JSON response ``{"answer": ...}`` on success;
        ``{"error": ...}`` with status 400 when the image or question is
        missing or the upload cannot be opened as an image;
        ``{"error": ...}`` with status 500 when analysis or answer
        generation fails.
    """
    image = request.files.get('image')
    question = request.form.get('question')
    complexity = request.form.get('complexity', 'Default')
    if not image or not question:
        return jsonify({"error": "Missing image or question"}), 400

    # Process the image
    try:
        image = Image.open(image).convert("RGB")
    except OSError:
        # Covers PIL's UnidentifiedImageError (an OSError subclass) raised
        # for uploads that are not decodable images: a client error.
        return jsonify({"error": "Invalid image file"}), 400

    # Analysis runs the models and can fail just like answer generation, so
    # both sit behind the same request-level error boundary.
    try:
        metadata = real_image_analysis(image)
        answer = generate_answer_from_metadata(metadata, question, complexity)
    except Exception:
        app.logger.exception("Error generating answer")
        return jsonify({"error": "Failed to generate answer"}), 500
    return jsonify({"answer": answer})
if __name__ == '__main__':
    # The interactive debugger lets anyone who can reach the server execute
    # arbitrary Python, so it is opt-in (FLASK_DEBUG=1) rather than always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in {'1', 'true', 'yes'}
    app.run(debug=debug_enabled)