# Rec_Sys_Flo2 / app.py
# (Hugging Face Space page metadata from the raw-file export — author
# "bgaspra", commit 70d5323, 5.22 kB — kept here as a comment so the
# module remains valid Python.)
import torch
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import pandas as pd
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import warnings
warnings.filterwarnings('ignore')
# Load Florence-2 model and processor
model_name = "microsoft/Florence-2-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 on GPU to halve memory; fp32 on CPU where half precision is poorly supported
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Modify model loading to disable flash attention
# trust_remote_code is required: Florence-2's modeling code ships with the
# checkpoint rather than with the transformers library itself.
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True
).to(device)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
# Load CivitAI dataset (limited to 1000 samples)
print("Loading dataset...")
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
df = pd.DataFrame(dataset)
print("Dataset loaded successfully!")
# Create cache for embeddings to improve performance
# Maps prompt string -> NumPy embedding; filled lazily by get_text_embedding()
# and eagerly by precompute_embeddings() at startup.
text_embedding_cache = {}
def get_image_embedding(image):
    """Encode *image* with Florence-2 and return its feature vector.

    Returns a NumPy array on success, or None if preprocessing/encoding
    raises (the error is printed, not propagated).

    NOTE(review): relies on ``model.get_image_features`` — confirm the
    Florence-2 remote modeling code actually exposes this method.
    """
    try:
        batch = processor(images=image, return_tensors="pt").to(device, torch_dtype)
        with torch.no_grad():
            features = model.get_image_features(**batch)
        return features.cpu().numpy()
    except Exception as e:
        print(f"Error in get_image_embedding: {str(e)}")
        return None
def get_text_embedding(text):
    """Encode *text* with Florence-2, memoizing in ``text_embedding_cache``.

    Returns the cached or freshly computed NumPy embedding, or None when
    encoding fails (the error is printed, not propagated).

    NOTE(review): relies on ``model.get_text_features`` — confirm the
    Florence-2 remote modeling code actually exposes this method.
    """
    try:
        # Cache never stores None, so a None result means "not cached yet".
        cached = text_embedding_cache.get(text)
        if cached is not None:
            return cached
        batch = processor(text=text, return_tensors="pt").to(device, torch_dtype)
        with torch.no_grad():
            features = model.get_text_features(**batch)
        vector = features.cpu().numpy()
        text_embedding_cache[text] = vector
        return vector
    except Exception as e:
        print(f"Error in get_text_embedding: {str(e)}")
        return None
def precompute_embeddings():
    """Warm ``text_embedding_cache`` with an embedding for every prompt in df.

    Fix: the progress message previously hard-coded ``/1000`` even though the
    dataset slice size is configured elsewhere; report the real row count.
    """
    total = len(df)
    print("Pre-computing text embeddings...")
    for idx, row in df.iterrows():
        if row['prompt'] not in text_embedding_cache:
            _ = get_text_embedding(row['prompt'])
        # Periodic progress log; idx is the DataFrame's integer index.
        if idx % 100 == 0:
            print(f"Processed {idx}/{total} embeddings")
    print("Finished pre-computing embeddings")
def find_similar_images(uploaded_image, top_k=5):
    """Recommend models and prompts whose text embeddings are closest to the image.

    Embeds *uploaded_image*, scores it against every cached/computed prompt
    embedding with cosine similarity, and returns two lists (each up to
    *top_k* long, deduplicated): model names and prompt strings, best first.
    Returns ([], []) if the image cannot be embedded or no prompt embeds.

    Fix: the original called sklearn's cosine_similarity once per dataset row
    (n separate 1x1 calls); here all prompt vectors are stacked and scored in
    a single vectorized call — identical scores, far less Python overhead.
    """
    query_embedding = get_image_embedding(uploaded_image)
    if query_embedding is None:
        return [], []

    # Collect rows whose prompts embedded successfully, plus their vectors.
    candidates = []
    vectors = []
    for _, row in df.iterrows():
        prompt_embedding = get_text_embedding(row['prompt'])
        if prompt_embedding is not None:
            candidates.append((row['Model'], row['prompt']))
            vectors.append(np.asarray(prompt_embedding).reshape(-1))
    if not vectors:
        return [], []

    # One (1, n) similarity matrix instead of n separate calls.
    scores = cosine_similarity(
        np.asarray(query_embedding).reshape(1, -1), np.vstack(vectors)
    )[0]

    top_models = []
    top_prompts = []
    seen_models = set()
    seen_prompts = set()
    # Walk candidates from most to least similar, deduplicating each list.
    for i in np.argsort(scores)[::-1]:
        model_name_i, prompt_i = candidates[i]
        if len(top_models) < top_k and model_name_i not in seen_models:
            top_models.append(model_name_i)
            seen_models.add(model_name_i)
        if len(top_prompts) < top_k and prompt_i not in seen_prompts:
            top_prompts.append(prompt_i)
            seen_prompts.add(prompt_i)
        if len(top_models) == top_k and len(top_prompts) == top_k:
            break
    return top_models, top_prompts
def process_image(input_image):
    """Gradio callback: turn an uploaded image into two recommendation texts.

    Returns a (models_text, prompts_text) pair. Either a placeholder prompt
    when no image was supplied, an error pair when anything fails, or two
    numbered lists built from find_similar_images().
    """
    if input_image is None:
        return "Please upload an image.", "Please upload an image."
    try:
        # Gradio may hand us a raw array; normalize to a PIL image.
        pil_image = (
            input_image
            if isinstance(input_image, Image.Image)
            else Image.fromarray(input_image)
        )
        models, prompts = find_similar_images(pil_image)
        if not models or not prompts:
            return "Error processing image.", "Error processing image."
        models_text = "Recommended Models:\n" + "\n".join(
            f"{rank}. {name}" for rank, name in enumerate(models, start=1)
        )
        prompts_text = "Recommended Prompts:\n" + "\n".join(
            f"{rank}. {text}" for rank, text in enumerate(prompts, start=1)
        )
        return models_text, prompts_text
    except Exception as e:
        print(f"Error in process_image: {str(e)}")
        return "Error processing image.", "Error processing image."
# Pre-compute embeddings when starting the application
# Best-effort warm-up: a failure here only slows the first requests (embeddings
# are then computed lazily per prompt), so the app still launches.
try:
    precompute_embeddings()
except Exception as e:
    print(f"Error in precompute_embeddings: {str(e)}")
# Create Gradio interface
# Single image input; two read-only textboxes mirror process_image's
# (models_text, prompts_text) return pair.
iface = gr.Interface(
fn=process_image,
inputs=gr.Image(type="pil", label="Upload AI-generated image"),
outputs=[
gr.Textbox(label="Recommended Models", lines=6),
gr.Textbox(label="Recommended Prompts", lines=6)
],
title="AI Image Model & Prompt Recommender",
description="Upload an AI-generated image to get recommendations for Stable Diffusion models and prompts.",
examples=[],
cache_examples=False
)
# Launch the interface
# Blocking call; runs the web server until the process is stopped.
iface.launch()