hw4 / app.py
mpi's picture
Create app.py
205fdc8 verified
Raw
History Blame Contribute Delete
4.06 kB
import streamlit as st
import torch
from transformers import CLIPProcessor, CLIPModel
import pandas as pd
import numpy as np
from PIL import Image
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@st.cache_resource
def load_clip_model():
try:
model = CLIPModel.from_pretrained("fine_tuned_clip_60")
processor = CLIPProcessor.from_pretrained("fine_tuned_clip_60")
except Exception as e:
st.warning(f"Не удалось загрузить fine-tuned модель: {e}. Загружаем pre-trained модель CLIP.")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()
return model, processor
@st.cache_data
def load_articles(csv_path):
df = pd.read_csv(csv_path)
return df
@st.cache_data
def load_text_embeddings(embeddings_file):
try:
embeddings = np.load(embeddings_file)
return embeddings
except Exception as e:
st.error(f"Ошибка загрузки текстовых эмбеддингов: {e}")
return None
def search_articles(query_image, model, processor, text_embeddings, df, top_k=5):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
inputs = processor(images=query_image, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
image_features = model.get_image_features(**inputs)
image_embed = image_features[0].cpu().numpy()
norm = np.linalg.norm(image_embed)
if norm > 0:
image_embed = image_embed / norm
except Exception as e:
st.error(f"Ошибка при вычислении эмбеддинга изображения: {e}")
return None
sims = np.dot(text_embeddings, image_embed)
top_indices = np.argsort(sims)[::-1][:top_k]
results = df.iloc[top_indices].copy()
results["similarity"] = sims[top_indices]
return results
st.title("Поиск Википедийной статьи по фото")
st.markdown("Загрузите изображение объекта, и система найдёт статьи, наиболее похожие по смыслу на данный объект.")
uploaded_file = st.file_uploader("Выберите изображение (jpg, jpeg, png)", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
try:
query_image = Image.open(uploaded_file).convert("RGB")
st.image(query_image, caption="Загруженное изображение", use_container_width=True)
except Exception as e:
st.error(f"Ошибка при открытии изображения: {e}")
query_image = None
else:
st.info("Пожалуйста, загрузите изображение для поиска.")
model, processor = load_clip_model()
csv_path = "wiki_image_text_pairs_10000.csv" # Update this path if necessary
df_articles = load_articles(csv_path)
embeddings_file = "text_embeddings_60.npy" # Update this path if necessary
text_embeddings = load_text_embeddings(embeddings_file)
if st.button("Найти статьи"):
if query_image is None:
st.error("Изображение не загружено!")
elif text_embeddings is None:
st.error("Текстовые эмбеддинги не загружены!")
else:
results = search_articles(query_image, model, processor, text_embeddings, df_articles, top_k=5)
if results is not None and not results.empty:
st.subheader("Найденные статьи:")
for idx, row in results.iterrows():
st.markdown(f"**{row['title']}** (Сходство: {row['similarity']*100:.1f}%)")
st.write(row['text'][:300] + "...")
st.markdown("---")
else:
st.error("Не удалось найти подходящие статьи.")