# imagesearch / app.py
# Author: yaswanth8390 — commit 4b5e7d6 ("Update app.py", verified)
# === Updated app.py with inline model testing and label name prediction ===
from flask import Flask, request, render_template, send_from_directory, redirect, url_for, jsonify
import os, shutil, uuid, zipfile, json, time, csv
import tensorflow as tf
import numpy as np
from PIL import Image
from flask import jsonify
import glob
import difflib
from flask import request, jsonify, url_for
import json
import os
from flask import send_from_directory
app = Flask(__name__)

# Working directories: per-session raw uploads live under UPLOAD_FOLDER;
# generated artefacts (zip packages, model summaries, labels.json) go to
# OUTPUT_FOLDER, with trained models and training-log CSVs in sub-folders.
UPLOAD_FOLDER = 'uploads'
OUTPUT_FOLDER = 'output'
MODEL_FOLDER = os.path.join(OUTPUT_FOLDER, 'models')
CSV_FOLDER = os.path.join(OUTPUT_FOLDER, 'csvs')

# Create them all up front so request handlers can write without checking.
for _required_dir in (UPLOAD_FOLDER, OUTPUT_FOLDER, MODEL_FOLDER, CSV_FOLDER):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
@app.route("/")
def index():
    """Landing page: render the ``index.html`` template."""
    page = "index.html"
    return render_template(page)
def _resolve_upload_target(upload_root, key):
    """Return the absolute save path for form-file ``key`` under ``upload_root``.

    ``key`` is client-controlled, so it is normalised first. ``None`` is
    returned when it is empty or would resolve outside ``upload_root``
    (absolute paths, ``..`` components, ...).
    """
    root = os.path.abspath(upload_root)
    target = os.path.abspath(os.path.join(root, os.path.normpath(key)))
    try:
        inside = os.path.commonpath([root, target]) == root
    except ValueError:  # e.g. paths on different drives on Windows
        inside = False
    if not inside or target == root:
        return None
    return target


@app.route("/upload", methods=["POST"])
def upload_files():
    """Save an uploaded image dataset, train a model on it and redirect to results.

    Each form-file key is used as a path relative to a fresh session folder.
    Its directory part becomes a sub-folder, and train_model treats each
    sub-folder as one class.

    Form fields:
        epochs: number of training epochs (default 5; must be an int >= 1).

    Returns:
        A redirect to ``result_page`` on success. Returns a 400 response when
        no files were sent, ``epochs`` is invalid, or a file key resolves
        outside the session folder.
    """
    if not request.files:
        return "No files received", 400

    try:
        epoch_count = int(request.form.get("epochs", 5))
    except (TypeError, ValueError):
        return "epochs must be an integer", 400
    if epoch_count < 1:
        return "epochs must be at least 1", 400

    session_id = uuid.uuid4().hex[:8]
    upload_path = os.path.join(UPLOAD_FOLDER, session_id)
    os.makedirs(upload_path, exist_ok=True)

    for key in request.files:
        target = _resolve_upload_target(upload_path, key)
        if target is None:
            # Refuse path-traversal attempts and drop the partial session.
            # The key is deliberately not echoed back, to avoid reflecting
            # client input into an HTML response.
            shutil.rmtree(upload_path, ignore_errors=True)
            return "Invalid upload path", 400
        os.makedirs(os.path.dirname(target), exist_ok=True)
        request.files[key].save(target)

    (model_filename, accuracy, _summary_path, label_map_file,
     training_time, csv_file, _history) = train_model(upload_path, session_id, epoch_count)

    # Bundle the model, label map and training log into one downloadable zip.
    zip_filename = f"model_package_{session_id}.zip"
    zip_path = os.path.join(OUTPUT_FOLDER, zip_filename)
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        zipf.write(os.path.join(MODEL_FOLDER, model_filename), arcname=model_filename)
        zipf.write(os.path.join(OUTPUT_FOLDER, label_map_file), arcname=label_map_file)
        zipf.write(os.path.join(CSV_FOLDER, csv_file), arcname=csv_file)

    return redirect(url_for("result_page", accuracy=accuracy, zipname=zip_filename,
                            summary_id=session_id, time=training_time, csv=csv_file,
                            model_file=model_filename))
@app.route('/chat_search', methods=['POST'])
def chat_search():
    """Chat-like image search over the class folders of one upload session.

    Expects a JSON body::

        {"query": "fridge", "session_id": "abc12345"}

    Class folder names are matched case-insensitively. A substring match is
    tried first; if nothing matches, ``difflib`` fuzzy matching is used (up
    to 5 names, cutoff 0.4). For each matched class, up to 10 images
    (.png/.jpg/.jpeg, sorted by file name) are returned as URLs served by
    ``serve_uploads``.

    Returns:
        ``{"results": [{"class": ..., "url": ...}, ...]}`` on success.
        Returns 400 with an ``error`` key when the query or session id is
        missing or invalid, or the session folder does not exist.
    """
    # silent=True: a missing or malformed JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    query = str(data.get("query") or "").strip().lower()
    session_id = data.get("session_id")
    if not query or not session_id:
        return jsonify({"error": "Query and session_id required"}), 400
    # session_id comes from the client and is joined into a filesystem path,
    # so accept only a bare directory name (no separators, no "..").
    if (not isinstance(session_id, str)
            or os.path.basename(session_id) != session_id
            or session_id in (".", "..")):
        return jsonify({"error": "Invalid session_id"}), 400

    session_folder = os.path.join(UPLOAD_FOLDER, session_id)
    if not os.path.isdir(session_folder):
        return jsonify({"error": "Session folder not found"}), 400

    # Every sub-directory of the session folder is a class.
    class_names = sorted(
        d for d in os.listdir(session_folder)
        if os.path.isdir(os.path.join(session_folder, d))
    )

    matches = [cls for cls in class_names if query in cls.lower()]
    if not matches:
        # The query is lower-cased, so fuzzy-match against lower-cased names
        # and map each hit back to the real folder name(s).
        by_lower = {}
        for cls in class_names:
            by_lower.setdefault(cls.lower(), []).append(cls)
        close = difflib.get_close_matches(query, list(by_lower), n=5, cutoff=0.4)
        matches = [cls for key in close for cls in by_lower[key]]

    results = []
    for cls in matches:
        class_folder = os.path.join(session_folder, cls)
        # Filter to images before applying the per-class cap, so non-image
        # files do not use up the limit; sort for a stable order.
        images = sorted(
            name for name in os.listdir(class_folder)
            if name.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        for img_file in images[:10]:
            results.append({
                "class": cls,
                "url": url_for('serve_uploads', filename=f"{session_id}/{cls}/{img_file}"),
            })
    return jsonify({"results": results})
@app.route('/uploads/<path:filename>')
def serve_uploads(filename):
    """Serve a file stored under UPLOAD_FOLDER.

    chat_search builds its image URLs against this endpoint, using paths of
    the form ``<session_id>/<class>/<image>``.
    """
    # Use the shared constant rather than a second hard-coded 'uploads'
    # literal, so both stay in sync if the upload location changes.
    return send_from_directory(UPLOAD_FOLDER, filename)
@app.route("/result")
def result_page():
    """Render the training results page for one upload session.

    Query parameters (as produced by the redirect in upload_files):
        accuracy, zipname, time, model_file: passed through to the template.
        summary_id: session id; selects ``OUTPUT_FOLDER/<summary_id>.txt``.
        csv: training-log file name inside CSV_FOLDER.

    Returns:
        The rendered ``result.html``. Returns 400 if ``summary_id`` or ``csv``
        is missing or not a bare file name, and 404 if the summary, CSV or
        labels file does not exist.
    """
    accuracy = request.args.get("accuracy")
    zipname = request.args.get("zipname")
    summary_id = request.args.get("summary_id")
    training_time = request.args.get("time")
    csv_file = request.args.get("csv")
    model_file = request.args.get("model_file")

    # Both values come from the URL and are joined into filesystem paths, so
    # reject anything that is not a bare file name (e.g. "../../secret").
    for value in (summary_id, csv_file):
        if not value or os.path.basename(value) != value or value in (".", ".."):
            return "Invalid or missing summary_id/csv parameter", 400

    labels_path = os.path.join(OUTPUT_FOLDER, "labels.json")
    summary_path = os.path.join(OUTPUT_FOLDER, f"{summary_id}.txt")
    csv_path = os.path.join(CSV_FOLDER, csv_file)
    if not all(os.path.isfile(p) for p in (labels_path, summary_path, csv_path)):
        return "Results not found", 404

    # NOTE(review): labels.json has a fixed, session-independent name and is
    # rewritten by each training run, so this shows the most recent run's labels.
    with open(labels_path, encoding='utf-8') as f:
        label_map = json.load(f)
    with open(summary_path, encoding='utf-8') as f:
        model_summary = f.read()

    history_data = []
    with open(csv_path, encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; tolerate an empty file
        for row in reader:
            if len(row) < 3:
                continue  # ignore blank/truncated rows instead of raising IndexError
            history_data.append({'epoch': row[0], 'accuracy': row[1], 'loss': row[2]})

    return render_template("result.html", accuracy=accuracy, zipname=zipname,
                           model_summary=model_summary, label_map=label_map,
                           training_time=training_time, csv_file=csv_file,
                           history=history_data, model_file=model_file,
                           session_id=summary_id)
@app.route("/test_model", methods=["POST"])
def test_model():
    """Classify one uploaded image with a previously trained model.

    Form fields:
        model_path: model file name inside MODEL_FOLDER.
        test_image: the image file to classify.

    The image is converted to RGB, resized to 64x64 and scaled to [0, 1]. This
    is the same preprocessing train_model applies.

    Returns:
        JSON ``{"label": ..., "confidence": "NN.NN"}`` (confidence in percent).
        Returns 400 if an input is missing or ``model_path`` is not a bare file
        name, 404 if the model file does not exist, and 500 with the error
        message if loading or prediction fails.
    """
    model_path = request.form.get("model_path")
    test_image = request.files.get("test_image")
    if not model_path or not test_image:
        return jsonify({'error': 'Missing model or image'}), 400
    # model_path is client-controlled: only accept a bare file name so the
    # request cannot make the server load an arbitrary file from disk.
    if os.path.basename(model_path) != model_path or model_path in (".", ".."):
        return jsonify({'error': 'Invalid model path'}), 400
    full_model_path = os.path.join(MODEL_FOLDER, model_path)
    if not os.path.isfile(full_model_path):
        return jsonify({'error': 'Model not found'}), 404

    try:
        # NOTE(review): labels.json has a fixed name and is overwritten by each
        # training run; it may not correspond to the model being tested.
        with open(os.path.join(OUTPUT_FOLDER, "labels.json"), encoding='utf-8') as f:
            label_map = json.load(f)
        index_to_label = {v: k for k, v in label_map.items()}
        model = tf.keras.models.load_model(full_model_path)
        # Decode directly from the upload stream. A single fixed temp file
        # would be shared (and clobbered) by concurrent requests.
        img = Image.open(test_image.stream).convert("RGB").resize((64, 64))
        img_array = np.array(img) / 255.0
        prediction = model.predict(np.expand_dims(img_array, axis=0))[0]
    except Exception as e:
        return jsonify({'error': str(e)}), 500

    predicted_index = int(np.argmax(prediction))
    predicted_label = index_to_label.get(predicted_index, str(predicted_index))
    confidence = float(np.max(prediction)) * 100
    return jsonify({
        'label': predicted_label,
        'confidence': f"{confidence:.2f}"
    })
@app.route("/download/<filename>")
def download_file(filename):
    """Send ``OUTPUT_FOLDER/<filename>`` as a file download.

    Files produced there include the ``model_package_<session>.zip`` bundles.
    """
    source_dir = OUTPUT_FOLDER
    return send_from_directory(
        source_dir,
        filename,
        as_attachment=True,
    )
def _load_dataset(data_folder):
    """Load the images under ``data_folder`` as a labelled training set.

    Every immediate sub-directory of ``data_folder`` is one class, named after
    the directory. Files ending in .png/.jpg/.jpeg (case-insensitive) are
    converted to RGB, resized to 64x64 and scaled to [0, 1]. Files PIL
    cannot open or decode are skipped.

    Returns:
        (X, y, class_to_index). X is the stacked image array, y the integer
        label array, and class_to_index maps class name to index, with
        classes sorted by name.
    """
    class_names = sorted(
        d for d in os.listdir(data_folder)
        if os.path.isdir(os.path.join(data_folder, d))
    )
    class_to_index = {name: idx for idx, name in enumerate(class_names)}
    image_data = []
    labels = []
    for folder in class_names:
        folder_path = os.path.join(data_folder, folder)
        for fname in os.listdir(folder_path):
            if not fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            try:
                with Image.open(os.path.join(folder_path, fname)) as img:
                    img_array = np.array(img.convert("RGB").resize((64, 64))) / 255.0
            except Exception:
                # Deliberate best effort: skip unreadable or corrupt files
                # rather than failing the whole upload. Catching Exception
                # (not a bare except) still lets KeyboardInterrupt and
                # SystemExit propagate.
                continue
            image_data.append(img_array)
            labels.append(class_to_index[folder])
    return np.array(image_data), np.array(labels), class_to_index


def train_model(data_folder, session_id, epochs):
    """Train a small CNN on an uploaded image dataset and save its artefacts.

    Args:
        data_folder: directory containing one sub-directory of images per class.
        session_id: identifier embedded in the output file names.
        epochs: number of training epochs.

    Writes:
        ``MODEL_FOLDER/trained_model_<session_id>.h5`` (the model),
        ``OUTPUT_FOLDER/labels.json`` (class name -> index),
        ``OUTPUT_FOLDER/<session_id>.txt`` (model summary), and
        ``CSV_FOLDER/training_log_<session_id>.csv`` (per-epoch accuracy/loss).

    Returns:
        (model_filename, final_accuracy_percent, summary_path, "labels.json",
        training_time_seconds, csv_filename, history_dict)

    Raises:
        ValueError: if no readable images were found under ``data_folder``.
    """
    X, y, class_to_index = _load_dataset(data_folder)
    if len(X) == 0:
        raise ValueError(f"No readable .png/.jpg/.jpeg images found under {data_folder}")

    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(len(class_to_index), activation='softmax')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    start = time.time()
    history = model.fit(X, y, epochs=epochs, verbose=1)
    training_time = round(time.time() - start, 2)
    final_accuracy = round(history.history['accuracy'][-1] * 100, 2)

    model_filename = f"trained_model_{session_id}.h5"
    model.save(os.path.join(MODEL_FOLDER, model_filename))

    # NOTE(review): a single labels.json is shared by every session, so each
    # training run overwrites the mapping that /result and /test_model read.
    label_path = os.path.join(OUTPUT_FOLDER, "labels.json")
    with open(label_path, 'w', encoding='utf-8') as f:
        json.dump(class_to_index, f)

    summary_path = os.path.join(OUTPUT_FOLDER, f"{session_id}.txt")
    with open(summary_path, 'w', encoding='utf-8') as f:
        # **_ tolerates extra keyword arguments to print_fn. NOTE(review):
        # some Keras versions appear to pass e.g. line_break — confirm.
        model.summary(print_fn=lambda line, **_: f.write(line + "\n"))

    csv_file = f"training_log_{session_id}.csv"
    csv_path = os.path.join(CSV_FOLDER, csv_file)
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['Epoch', 'Accuracy', 'Loss'])
        per_epoch = zip(history.history['accuracy'], history.history['loss'])
        for epoch, (acc, loss) in enumerate(per_epoch, start=1):
            writer.writerow([epoch, acc, loss])

    return model_filename, final_accuracy, summary_path, "labels.json", training_time, csv_file, history.history
if __name__ == '__main__':
    # The interactive Werkzeug debugger can execute arbitrary code from the
    # browser, so do not enable it by default on a server bound to all
    # interfaces. Set FLASK_DEBUG=1 (or true/yes) to opt in for local work.
    debug_mode = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=debug_mode, host='0.0.0.0', port=8000)