import streamlit as st import torch from transformers import CLIPProcessor, CLIPModel import pandas as pd import numpy as np from PIL import Image import os os.environ["TOKENIZERS_PARALLELISM"] = "false" @st.cache_resource def load_clip_model(): try: model = CLIPModel.from_pretrained("fine_tuned_clip_60") processor = CLIPProcessor.from_pretrained("fine_tuned_clip_60") except Exception as e: st.warning(f"Не удалось загрузить fine-tuned модель: {e}. Загружаем pre-trained модель CLIP.") model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") model.eval() return model, processor @st.cache_data def load_articles(csv_path): df = pd.read_csv(csv_path) return df @st.cache_data def load_text_embeddings(embeddings_file): try: embeddings = np.load(embeddings_file) return embeddings except Exception as e: st.error(f"Ошибка загрузки текстовых эмбеддингов: {e}") return None def search_articles(query_image, model, processor, text_embeddings, df, top_k=5): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") try: inputs = processor(images=query_image, return_tensors="pt", padding=True, truncation=True) inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): image_features = model.get_image_features(**inputs) image_embed = image_features[0].cpu().numpy() norm = np.linalg.norm(image_embed) if norm > 0: image_embed = image_embed / norm except Exception as e: st.error(f"Ошибка при вычислении эмбеддинга изображения: {e}") return None sims = np.dot(text_embeddings, image_embed) top_indices = np.argsort(sims)[::-1][:top_k] results = df.iloc[top_indices].copy() results["similarity"] = sims[top_indices] return results st.title("Поиск Википедийной статьи по фото") st.markdown("Загрузите изображение объекта, и система найдёт статьи, наиболее похожие по смыслу на данный объект.") uploaded_file = st.file_uploader("Выберите изображение (jpg, jpeg, png)", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: try: query_image = Image.open(uploaded_file).convert("RGB") st.image(query_image, caption="Загруженное изображение", use_container_width=True) except Exception as e: st.error(f"Ошибка при открытии изображения: {e}") query_image = None else: st.info("Пожалуйста, загрузите изображение для поиска.") model, processor = load_clip_model() csv_path = "wiki_image_text_pairs_10000.csv" # Update this path if necessary df_articles = load_articles(csv_path) embeddings_file = "text_embeddings_60.npy" # Update this path if necessary text_embeddings = load_text_embeddings(embeddings_file) if st.button("Найти статьи"): if query_image is None: st.error("Изображение не загружено!") elif text_embeddings is None: st.error("Текстовые эмбеддинги не загружены!") else: results = search_articles(query_image, model, processor, text_embeddings, df_articles, top_k=5) if results is not None and not results.empty: st.subheader("Найденные статьи:") for idx, row in results.iterrows(): st.markdown(f"**{row['title']}** (Сходство: {row['similarity']*100:.1f}%)") st.write(row['text'][:300] + "...") st.markdown("---") else: st.error("Не удалось найти подходящие статьи.")