| import streamlit as st |
| import torch |
| from transformers import CLIPProcessor, CLIPModel |
| import pandas as pd |
| import numpy as np |
| from PIL import Image |
|
|
| import os |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
| @st.cache_resource |
| def load_clip_model(): |
| try: |
| model = CLIPModel.from_pretrained("fine_tuned_clip_60") |
| processor = CLIPProcessor.from_pretrained("fine_tuned_clip_60") |
| except Exception as e: |
| st.warning(f"Не удалось загрузить fine-tuned модель: {e}. Загружаем pre-trained модель CLIP.") |
| model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") |
| processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") |
| model.eval() |
| return model, processor |
|
|
| @st.cache_data |
| def load_articles(csv_path): |
| df = pd.read_csv(csv_path) |
| return df |
|
|
| @st.cache_data |
| def load_text_embeddings(embeddings_file): |
| try: |
| embeddings = np.load(embeddings_file) |
| return embeddings |
| except Exception as e: |
| st.error(f"Ошибка загрузки текстовых эмбеддингов: {e}") |
| return None |
|
|
| def search_articles(query_image, model, processor, text_embeddings, df, top_k=5): |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| try: |
| inputs = processor(images=query_image, return_tensors="pt", padding=True, truncation=True) |
| inputs = {k: v.to(device) for k, v in inputs.items()} |
| with torch.no_grad(): |
| image_features = model.get_image_features(**inputs) |
| image_embed = image_features[0].cpu().numpy() |
| norm = np.linalg.norm(image_embed) |
| if norm > 0: |
| image_embed = image_embed / norm |
| except Exception as e: |
| st.error(f"Ошибка при вычислении эмбеддинга изображения: {e}") |
| return None |
|
|
| sims = np.dot(text_embeddings, image_embed) |
| top_indices = np.argsort(sims)[::-1][:top_k] |
| results = df.iloc[top_indices].copy() |
| results["similarity"] = sims[top_indices] |
| return results |
|
|
| st.title("Поиск Википедийной статьи по фото") |
| st.markdown("Загрузите изображение объекта, и система найдёт статьи, наиболее похожие по смыслу на данный объект.") |
|
|
| uploaded_file = st.file_uploader("Выберите изображение (jpg, jpeg, png)", type=["jpg", "jpeg", "png"]) |
|
|
| if uploaded_file is not None: |
| try: |
| query_image = Image.open(uploaded_file).convert("RGB") |
| st.image(query_image, caption="Загруженное изображение", use_container_width=True) |
| except Exception as e: |
| st.error(f"Ошибка при открытии изображения: {e}") |
| query_image = None |
| else: |
| st.info("Пожалуйста, загрузите изображение для поиска.") |
|
|
| model, processor = load_clip_model() |
| csv_path = "wiki_image_text_pairs_10000.csv" |
| df_articles = load_articles(csv_path) |
| embeddings_file = "text_embeddings_60.npy" |
| text_embeddings = load_text_embeddings(embeddings_file) |
|
|
| if st.button("Найти статьи"): |
| if query_image is None: |
| st.error("Изображение не загружено!") |
| elif text_embeddings is None: |
| st.error("Текстовые эмбеддинги не загружены!") |
| else: |
| results = search_articles(query_image, model, processor, text_embeddings, df_articles, top_k=5) |
| if results is not None and not results.empty: |
| st.subheader("Найденные статьи:") |
| for idx, row in results.iterrows(): |
| st.markdown(f"**{row['title']}** (Сходство: {row['similarity']*100:.1f}%)") |
| st.write(row['text'][:300] + "...") |
| st.markdown("---") |
| else: |
| st.error("Не удалось найти подходящие статьи.") |
|
|