import torch
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import pandas as pd
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import warnings

warnings.filterwarnings('ignore')
# Load the Florence-2 model and processor
model_name = "microsoft/Florence-2-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Request the SDPA attention implementation so loading does not require the
# optional flash-attn package
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    attn_implementation="sdpa",
    trust_remote_code=True,
).to(device)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
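
# NOTE (assumption): the embedding helpers below call model.get_image_features and
# model.get_text_features, which is a CLIP-style API; Florence-2's remote code may
# not expose these methods. If it does not, a contrastive model is a drop-in
# replacement for the retrieval step, e.g. (hypothetical alternative, not the
# original setup):
#
#   from transformers import CLIPModel, CLIPProcessor
#   model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",
#                                     torch_dtype=torch_dtype).to(device)
#   processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")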

# Load the CivitAI dataset (limited to the first 1,000 samples)
print("Loading dataset...")
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
df = pd.DataFrame(dataset)
print("Dataset loaded successfully!")

# Cache text embeddings so each prompt is only encoded once
text_embedding_cache = {}

def get_image_embedding(image):
    try:
        inputs = processor(images=image, return_tensors="pt").to(device, torch_dtype)
        with torch.no_grad():
            outputs = model.get_image_features(**inputs)
        # Cast to float32 so the downstream NumPy/scikit-learn code handles the array cleanly
        return outputs.float().cpu().numpy()
    except Exception as e:
        print(f"Error in get_image_embedding: {e}")
        return None

def get_text_embedding(text):
    try:
        if text in text_embedding_cache:
            return text_embedding_cache[text]
        inputs = processor(text=text, return_tensors="pt").to(device, torch_dtype)
        with torch.no_grad():
            outputs = model.get_text_features(**inputs)
        embedding = outputs.float().cpu().numpy()
        text_embedding_cache[text] = embedding
        return embedding
    except Exception as e:
        print(f"Error in get_text_embedding: {e}")
        return None

def precompute_embeddings():
    print("Pre-computing text embeddings...")
    for idx, row in df.iterrows():
        if row['prompt'] not in text_embedding_cache:
            _ = get_text_embedding(row['prompt'])
        if idx % 100 == 0:
            print(f"Processed {idx}/{len(df)} embeddings")
    print("Finished pre-computing embeddings")

def find_similar_images(uploaded_image, top_k=5):
    query_embedding = get_image_embedding(uploaded_image)
    if query_embedding is None:
        return [], []

    # Score every prompt in the dataset against the uploaded image
    similarities = []
    for idx, row in df.iterrows():
        prompt_embedding = get_text_embedding(row['prompt'])
        if prompt_embedding is not None:
            similarity = cosine_similarity(query_embedding, prompt_embedding)[0][0]
            similarities.append({
                'similarity': similarity,
                'model': row['Model'],
                'prompt': row['prompt'],
            })

    # Collect the top-k unique models and unique prompts from the ranked results
    sorted_results = sorted(similarities, key=lambda x: x['similarity'], reverse=True)
    top_models = []
    top_prompts = []
    seen_models = set()
    seen_prompts = set()
    for result in sorted_results:
        if len(top_models) < top_k and result['model'] not in seen_models:
            top_models.append(result['model'])
            seen_models.add(result['model'])
        if len(top_prompts) < top_k and result['prompt'] not in seen_prompts:
            top_prompts.append(result['prompt'])
            seen_prompts.add(result['prompt'])
        if len(top_models) == top_k and len(top_prompts) == top_k:
            break
    return top_models, top_prompts
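
# Optional sketch: once precompute_embeddings() has filled the cache, the per-row
# loop above can be replaced by a single vectorized similarity computation. This
# assumes every cached embedding has the same shape (1, D);
# find_similar_images_fast is a hypothetical alternative, not part of the original app.
def find_similar_images_fast(uploaded_image, top_k=5):
    query_embedding = get_image_embedding(uploaded_image)
    if query_embedding is None:
        return [], []
    rows = [(row['Model'], row['prompt']) for _, row in df.iterrows()
            if row['prompt'] in text_embedding_cache]
    if not rows:
        return [], []
    matrix = np.vstack([text_embedding_cache[prompt].reshape(1, -1)
                        for _, prompt in rows])                        # (N, D)
    sims = cosine_similarity(query_embedding.reshape(1, -1), matrix)[0]  # (N,)
    order = np.argsort(sims)[::-1]
    top_models, top_prompts = [], []
    seen_models, seen_prompts = set(), set()
    for i in order:
        model_i, prompt_i = rows[i]
        if len(top_models) < top_k and model_i not in seen_models:
            top_models.append(model_i)
            seen_models.add(model_i)
        if len(top_prompts) < top_k and prompt_i not in seen_prompts:
            top_prompts.append(prompt_i)
            seen_prompts.add(prompt_i)
        if len(top_models) == top_k and len(top_prompts) == top_k:
            break
    return top_models, top_prompts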

def process_image(input_image):
    if input_image is None:
        return "Please upload an image.", "Please upload an image."
    try:
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)
        recommended_models, recommended_prompts = find_similar_images(input_image)
        if not recommended_models or not recommended_prompts:
            return "Error processing image.", "Error processing image."
        models_text = "Recommended Models:\n" + "\n".join(
            [f"{i+1}. {model}" for i, model in enumerate(recommended_models)])
        prompts_text = "Recommended Prompts:\n" + "\n".join(
            [f"{i+1}. {prompt}" for i, prompt in enumerate(recommended_prompts)])
        return models_text, prompts_text
    except Exception as e:
        print(f"Error in process_image: {e}")
        return "Error processing image.", "Error processing image."

# Pre-compute embeddings at startup so the first request is fast
try:
    precompute_embeddings()
except Exception as e:
    print(f"Error in precompute_embeddings: {e}")
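
# Optional sketch (assumption): on a Space, precompute_embeddings() runs again on
# every restart. The cache could instead be persisted to disk and reloaded, e.g.
# with hypothetical file names:
#
#   np.save("prompt_embeddings.npy", np.vstack(list(text_embedding_cache.values())))
#   pd.Series(list(text_embedding_cache.keys())).to_csv("prompts.csv", index=False)
#
# and rebuilt with np.load / pd.read_csv before falling back to recomputation.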

# Create the Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload AI-generated image"),
    outputs=[
        gr.Textbox(label="Recommended Models", lines=6),
        gr.Textbox(label="Recommended Prompts", lines=6),
    ],
    title="AI Image Model & Prompt Recommender",
    description="Upload an AI-generated image to get recommendations for Stable Diffusion models and prompts.",
    examples=[],
    cache_examples=False,
)

# Launch the interface
iface.launch()
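
# On Hugging Face Spaces, launch() picks up the host and port automatically; for
# local testing, a public link can be requested with iface.launch(share=True).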