# alexnguyen25 — commit 8c6673d: "Initial commit without VDB"
import app as gr
import torch
from PIL import Image
import chromadb
from scripts.qwen3_vl_embedding import Qwen3VLEmbedder
from scripts.qwen3_vl_reranker import Qwen3VLReranker
# Configuration
VDB_PATH = "./VDB"  # persistent ChromaDB directory shipped with the Space
# Number of nearest AAT terms to retrieve per image.
# NOTE(review): confirm the ChromaDB query actually uses this constant
# rather than a hard-coded literal.
TERM_COUNT = 10

# Load ChromaDB
print("Loading ChromaDB collection...")
chroma_client = chromadb.PersistentClient(path=VDB_PATH)
collection = chroma_client.get_collection(name="aat_terms")

# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU: many CPU ops are slow or unsupported in
# float16, so use float32 there. Shared by both models so they agree
# (previously the reranker was forced to float16 even on CPU).
MODEL_DTYPE = torch.float16 if device == "cuda" else torch.float32
# NOTE(review): `device` is only used for dtype selection and log messages;
# it is not passed to the model wrappers -- confirm they place themselves
# on the GPU when one is available.

print(f"Loading Qwen embedding model on {device}...")
embedding_model = Qwen3VLEmbedder(
    model_name_or_path="Qwen/Qwen3-VL-Embedding-2B",
    dtype=MODEL_DTYPE,
)
print(f"Loading Qwen reranking model on {device}...")
reranking_model = Qwen3VLReranker(
    model_name_or_path="Qwen/Qwen3-VL-Reranker-2B",
    dtype=MODEL_DTYPE,
)
print("Models loaded successfully!")
def _rank_terms_for_image(image):
    """Return ``(score, term_label)`` pairs for one image, best match first.

    Embeds ``image``, fetches the ``TERM_COUNT`` nearest AAT term documents
    from ChromaDB, and reorders them with the reranker. Returns an empty
    list when the collection yields no candidates.
    """
    art_query = {"image": image, "text": ""}

    # Embed and L2-normalise the query (presumably matching how the stored
    # term vectors were normalised -- confirm against the indexing script).
    image_features = embedding_model.process([art_query])
    image_features = torch.nn.functional.normalize(image_features, p=2, dim=1)

    results = collection.query(
        query_embeddings=image_features.cpu().float().tolist(),
        n_results=TERM_COUNT,
    )

    # ChromaDB returns one result list per query embedding; we sent one.
    labels = []
    input_docs = []
    if results["documents"]:
        for doc, metadata in zip(results["documents"][0], results["metadatas"][0]):
            input_docs.append({"text": doc})
            labels.append(metadata["term_label"])

    # Nothing retrieved: don't call the reranker with an empty candidate list.
    if not input_docs:
        return []

    scores = reranking_model.process(
        {
            "instruction": "Retrieve Art & Architecture Thesaurus terms relevant to the given image.",
            "query": art_query,
            "documents": input_docs,
            "fps": 1.0,
        }
    )
    return sorted(zip(scores, labels), reverse=True)


def process_multiple_images(images, state):
    """Generate ranked AAT keywords for every uploaded image.

    Args:
        images: Uploaded files from the ``gr.File`` component (anything
            ``PIL.Image.open`` accepts; presumably file paths -- confirm for
            the installed Gradio version). May be ``None`` or empty.
        state: Previous session state. Ignored: each run starts fresh.

    Returns:
        ``(image, status, checkbox_update, export_text, state)`` for the
        first image, matching the ``process_btn.click`` outputs.

    An image that fails to open or to score is still recorded, with no
    keywords (and ``None`` as its picture if it could not be decoded),
    instead of aborting the whole batch.
    """
    # Each run replaces any previous results.
    state = {"all_results": {}, "current_index": 0}

    if not images:
        return None, "No images uploaded", gr.update(choices=[], value=[]), "", state

    print(f"Processing {len(images)} images...")
    for idx, image_path in enumerate(images):
        image = None
        try:
            opened = Image.open(image_path)
            # Force decoding now: a corrupt file fails inside this try, and
            # Pillow releases the file handle it opened for a path.
            opened.load()
            image = opened
            ranked = _rank_terms_for_image(image)
            print(f"Processed image {idx + 1}/{len(images)}")
        except Exception as e:
            # Best effort per image: log it and keep going. The image is not
            # reopened here -- if opening was what failed, that would raise
            # again outside the try and abort the whole batch.
            print(f"Error processing image {idx}: {e}")
            ranked = []

        # All keywords start out selected.
        state["all_results"][idx] = {
            "image": image,
            "keywords": [label for _, label in ranked],
            "scores": [score for score, _ in ranked],
            "selected": [True] * len(ranked),
        }

    # Show first image
    img, status, checkbox_update = show_image(0, state)
    return img, status, checkbox_update, "", state
def show_image(index, state):
    """Render the stored image at ``index`` together with its keyword checklist.

    Returns a 3-tuple ``(image, status_text, checkbox_update)``. If nothing
    is stored at ``index``, returns ``None``, an explanatory message and an
    empty checkbox update, and leaves ``state["current_index"]`` untouched.
    """
    results_by_index = state["all_results"]
    if index not in results_by_index:
        return None, f"No image at index {index}", gr.update(choices=[], value=[])

    state["current_index"] = index
    entry = results_by_index[index]

    # Checkbox labels look like "term (87.5%)"; update_selections rebuilds
    # the same text to decide which boxes are ticked, so keep them in sync.
    labels = [
        f"{keyword} ({score * 100:.1f}%)"
        for keyword, score in zip(entry["keywords"], entry["scores"])
    ]
    ticked = [labels[pos] for pos, keep in enumerate(entry["selected"]) if keep]

    status = f"Image {index + 1} of {len(results_by_index)}"
    return entry["image"], status, gr.update(choices=labels, value=ticked)
def update_selections(selected_keywords, state):
    """Record which keywords the user kept for the currently displayed image.

    Args:
        selected_keywords: Ticked checkbox labels, in the ``"term (NN.N%)"``
            form used for the CheckboxGroup choices. ``None`` is treated as
            "nothing ticked".
        state: Session dict holding ``"all_results"`` and ``"current_index"``.

    Returns:
        The same ``state`` object, updated in place. It is returned unchanged
        when no result exists for the current index.
    """
    all_results = state["all_results"]
    current_index = state["current_index"]
    if current_index not in all_results:
        return state

    result = all_results[current_index]
    # A set gives O(1) membership instead of scanning the list per keyword.
    ticked = set(selected_keywords or ())
    # Rebuild each label exactly as the checkbox choices are rendered, so the
    # membership test matches the displayed text.
    result["selected"] = [
        f"{keyword} ({score * 100:.1f}%)" in ticked
        for keyword, score in zip(result["keywords"], result["scores"])
    ]
    return state
def next_image(state):
    """Advance to the following image, staying put when already on the last one.

    Returns ``(image, status_text, checkbox_update, state)``.
    """
    here = state["current_index"]
    candidate = here + 1
    target = candidate if candidate < len(state["all_results"]) else here
    shown_image, status, checkboxes = show_image(target, state)
    return shown_image, status, checkboxes, state
def previous_image(state):
    """Step back to the preceding image, staying put when already on the first.

    Returns ``(image, status_text, checkbox_update, state)``.
    """
    here = state["current_index"]
    candidate = here - 1
    target = here if candidate < 0 else candidate
    shown_image, status, checkboxes = show_image(target, state)
    return shown_image, status, checkboxes, state
def export_results(state):
    """Summarise the kept keywords for every processed image.

    Produces one ``"Image N: kw1, kw2"`` entry per image in ascending index
    order (N is the 1-based index), separated by blank lines. Returns an
    empty string when nothing has been processed.
    """
    results_by_index = state["all_results"]
    lines = []
    for position in sorted(results_by_index):
        entry = results_by_index[position]
        keywords = entry["keywords"]
        kept = [keywords[i] for i, keep in enumerate(entry["selected"]) if keep]
        joined = ", ".join(kept)
        lines.append(f"Image {position + 1}: {joined}")
    return "\n\n".join(lines)
# Create Gradio interface
# NOTE(review): `gr` is imported at the top of the file from a module named
# `app`, but everything below uses Gradio's Blocks API (Blocks, State, File,
# Image, CheckboxGroup, ...) -- confirm the import is meant to be
# `import gradio as gr`.
with gr.Blocks() as interface:
    gr.Markdown("# MCAM Art Keyword Generator")
    gr.Markdown(
        "Upload multiple images, review keywords for each, and export selected keywords"
    )
    # Session state
    # "all_results" maps the 0-based upload index to that image's result dict
    # (image, keywords, scores, selected); "current_index" is the image on
    # screen. Every handler below takes it as input and returns it.
    state = gr.State({"all_results": {}, "current_index": 0})
    # Upload controls and the status line.
    with gr.Row():
        with gr.Column():
            upload_input = gr.File(
                file_count="multiple",
                file_types=["image"],
                label="Upload Images (multiple files)",
            )
            process_btn = gr.Button("Process All Images", variant="primary")
        with gr.Column():
            status_text = gr.Textbox(label="Status", interactive=False)
    # Image viewer with navigation, next to the keyword checklist.
    with gr.Row():
        with gr.Column(scale=1):
            current_image = gr.Image(label="Current Image", type="pil")
            with gr.Row():
                prev_btn = gr.Button("← Previous")
                next_btn = gr.Button("Next →")
        with gr.Column(scale=1):
            gr.Markdown("### Select Keywords to Keep")
            # Choices are filled in per image by the handlers.
            keyword_checkboxes = gr.CheckboxGroup(
                choices=[],
                label="Keywords (check to keep, uncheck to remove)",
                interactive=True,
            )
    # Export controls.
    with gr.Row():
        export_btn = gr.Button("Export Selected Keywords", variant="primary")
    export_output = gr.Textbox(label="Final Keywords for All Images", lines=10)
    # Wire up the interface
    # Process all uploads; shows the first image and clears the export box.
    process_btn.click(
        fn=process_multiple_images,
        inputs=[upload_input, state],
        outputs=[current_image, status_text, keyword_checkboxes, export_output, state],
    )
    # Any tick/untick is written back into the current image's "selected" flags.
    keyword_checkboxes.change(
        fn=update_selections,
        inputs=[keyword_checkboxes, state],
        outputs=[state],
    )
    # Navigation; both handlers stay on the current image at the ends.
    next_btn.click(
        fn=next_image,
        inputs=[state],
        outputs=[current_image, status_text, keyword_checkboxes, state],
    )
    prev_btn.click(
        fn=previous_image,
        inputs=[state],
        outputs=[current_image, status_text, keyword_checkboxes, state],
    )
    # Render the kept keywords for all images as text.
    export_btn.click(
        fn=export_results,
        inputs=[state],
        outputs=[export_output],
    )
# Start the web app only when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()