Spaces:
Running
Running
Upload app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flask import Flask, render_template, request, jsonify
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
import torch
|
| 5 |
+
from torchvision import models, transforms
|
| 6 |
+
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
app = Flask(__name__)
|
| 10 |
+
|
| 11 |
+
# Load ImageNet class index
def load_imagenet_class_index():
    """Read the ImageNet label list from 'imagenet_classes.txt'.

    Returns:
        list[str]: one stripped label per line, in file order, so that
        list index i corresponds to ImageNet class id i.

    Raises:
        FileNotFoundError: when the label file is not present in the
        working directory.
    """
    class_index_path = 'imagenet_classes.txt'
    if not os.path.exists(class_index_path):
        raise FileNotFoundError(f"ImageNet class index file not found at {class_index_path}")

    # Iterate the file handle directly; no need to materialize readlines().
    with open(class_index_path) as f:
        return [entry.strip() for entry in f]
| 20 |
+
|
| 21 |
+
# Labels are loaded once at import time; index i maps to ImageNet class i.
imagenet_classes = load_imagenet_class_index()

# Load pre-trained models (torchvision downloads weights on first use)
# and put both in inference mode so dropout/batch-norm behave correctly.
resnet = models.resnet50(pretrained=True)
resnet.eval()
fasterrcnn = fasterrcnn_resnet50_fpn(pretrained=True)
fasterrcnn.eval()

# Image transformation: standard ImageNet preprocessing — resize, center
# crop to 224x224, tensor conversion, per-channel mean/std normalization.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# COCO dataset class names indexed by detector label id; the 'N/A'
# entries pad label ids that have no category so that list index ==
# label id emitted by the Faster R-CNN model.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
|
| 52 |
+
|
| 53 |
+
# Function for real image analysis
def real_image_analysis(image):
    """Classify an image and detect objects, returning simple metadata.

    Args:
        image: RGB PIL.Image to analyze.

    Returns:
        dict with keys:
            "objects": up to 5 de-duplicated labels (ImageNet top-3 plus
                confident Faster R-CNN detections),
            "colors": dominant color names from get_dominant_colors,
            "scene": "outdoor" or "indoor" keyword heuristic.
    """
    # Prepare image for classification (resize/crop/normalize pipeline).
    img_t = transform(image)
    batch_t = torch.unsqueeze(img_t, 0)

    # Classification — inference only, so autograd is disabled.
    with torch.no_grad():
        output = resnet(batch_t)

    # Get top 3 predictions. topk replaces the original full sort, and the
    # softmax percentages computed before were never used, so they are gone.
    _, top_idx = torch.topk(output, 3, dim=1)
    objects = [imagenet_classes[idx.item()] for idx in top_idx[0]]

    # Object detection using Faster R-CNN (expects an un-normalized tensor).
    img_tensor = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        prediction = fasterrcnn(img_tensor)

    # Bug fix: keep only confident detections. The original code included
    # every proposal regardless of score, flooding the object list with
    # low-confidence noise.
    detected_objects = [
        COCO_INSTANCE_CATEGORY_NAMES[int(label)]
        for label, score in zip(prediction[0]['labels'], prediction[0]['scores'])
        if float(score) >= 0.5
    ]
    objects.extend(detected_objects)
    objects = list(set(objects))  # Remove duplicates

    # Get dominant colors
    colors = get_dominant_colors(image)

    # Determine scene (indoor/outdoor). NOTE(review): these keywords are not
    # COCO categories, so the heuristic relies on them appearing among the
    # ImageNet labels — verify against imagenet_classes.txt.
    scene = "outdoor" if any(obj in ['sky', 'tree', 'grass', 'mountain'] for obj in objects) else "indoor"

    return {
        "objects": objects[:5],  # Limit to top 5 objects
        "colors": colors,
        "scene": scene
    }
|
| 89 |
+
|
| 90 |
+
# Function to get dominant colors from an image
def get_dominant_colors(image, num_colors=3):
    """Return the names of up to *num_colors* dominant colors in *image*.

    Args:
        image: PIL.Image to inspect (not modified; a copy is used).
        num_colors: maximum number of color names to return.

    Returns:
        list[str]: coarse color names ("red"/"green"/"blue"/"gray"),
        most frequent first; may be shorter than num_colors for images
        with few distinct colors.
    """
    # Resize image to speed up processing
    img = image.copy()
    img.thumbnail((100, 100))

    # Quantize to an adaptive palette and rank palette entries by pixel count.
    paletted = img.convert('P', palette=Image.ADAPTIVE, colors=num_colors)
    palette = paletted.getpalette()
    color_counts = sorted(paletted.getcolors(), reverse=True)
    colors = []
    # Bug fix: the original `for i in range(num_colors)` raised IndexError
    # when the image produced fewer than num_colors palette entries; slicing
    # clamps iteration to the entries that actually exist.
    for _, palette_index in color_counts[:num_colors]:
        # Each palette entry occupies 3 consecutive bytes (R, G, B).
        dominant_color = palette[palette_index * 3:palette_index * 3 + 3]
        colors.append(rgb_to_name(dominant_color))
    return colors
|
| 106 |
+
|
| 107 |
+
# Function to convert RGB to color name (simplified)
def rgb_to_name(rgb):
    """Map an (r, g, b) triple to a coarse color name.

    Returns "red", "green", or "blue" when that channel strictly exceeds
    both others, and "gray" for ties or otherwise balanced values.
    """
    r, g, b = rgb
    if b < r > g:
        return "red"
    if r < g > b:
        return "green"
    if g < b > r:
        return "blue"
    return "gray"
|
| 118 |
+
|
| 119 |
+
# Function to simulate the generation of answers from metadata
def generate_answer_from_metadata(metadata, question, complexity):
    """Produce a (simulated) answer to *question* from image metadata.

    Args:
        metadata: dict with 'objects' (list[str]), 'colors' (list[str]),
            and 'scene' (str) keys, as produced by real_image_analysis.
        question: the user's question about the image.
        complexity: desired response complexity label (e.g. "Default").

    Returns:
        str: a deterministic placeholder answer string.
    """
    # The prompt is assembled exactly as it would be sent to an LLM client.
    prompt = f"""
The image contains the following objects: {', '.join(metadata['objects'])}.
The dominant colors are {', '.join(metadata['colors'])}.
It appears to be an {metadata['scene']} scene.

Based on this, provide a {complexity.lower()} response to the following question: {question}
"""

    # No `client` is wired up, so return a canned answer instead of making
    # an API call; replace this with real client code when one is available.
    answer = f"Simulated answer based on metadata: {metadata}. Question: {question}, Complexity: {complexity}."
    return answer
|
| 132 |
+
|
| 133 |
+
# Flask routes
@app.route('/')
def index():
    """Render and return the landing page template."""
    page = render_template('index.html')
    return page
|
| 137 |
+
|
| 138 |
+
@app.route('/ask', methods=['POST'])
def ask_question():
    """Handle an image + question POST and return a JSON answer.

    Expects multipart form data with an 'image' file, a 'question' field,
    and an optional 'complexity' field (default 'Default').

    Returns:
        JSON {"answer": ...} on success; {"error": ...} with status 400
        for missing input or an unreadable image, or 500 when answer
        generation fails.
    """
    image = request.files.get('image')
    question = request.form.get('question')
    complexity = request.form.get('complexity', 'Default')

    if not image or not question:
        return jsonify({"error": "Missing image or question"}), 400

    # Bug fix: the original decoded the upload outside any try block, so a
    # corrupt or non-image file produced an unhandled HTML 500 instead of
    # the JSON error contract the client expects.
    try:
        image = Image.open(image).convert("RGB")
    except Exception as e:
        print(f"Error reading image: {str(e)}")
        return jsonify({"error": "Invalid image file"}), 400

    try:
        # Perform real image analysis, then generate the answer from it.
        # (Analysis was previously outside the try block as well, leaving
        # model failures unhandled.)
        metadata = real_image_analysis(image)
        answer = generate_answer_from_metadata(metadata, question, complexity)
        return jsonify({"answer": answer})
    except Exception as e:
        print(f"Error generating answer: {str(e)}")
        return jsonify({"error": "Failed to generate answer"}), 500
|
| 160 |
+
|
| 161 |
+
if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug debugger and reloader,
    # which allows arbitrary code execution if exposed — confirm this is a
    # development-only entry point before deploying.
    app.run(debug=True)
|