search_jurist2 / app.py
ramdane's picture
Update app.py
e13bc8e
import pickle
import os
import torch
# Debug output: show the working directory the app was started from.
print(os.getcwd())
# Load the precomputed corpus embeddings together with the corpus entries;
# the pickle holds a 2-tuple (embeddings, corpus).
# NOTE(review): pickle.load can execute arbitrary code stored in the file. This
# is acceptable only because the file ships with the app; never point it at
# user-supplied data.
# The `with` block guarantees the file is closed even if unpickling fails.
with open("/home/user/app/embmmn7.obj", "rb") as fileobj:
    corpus_embeddings, corpus = pickle.load(fileobj)
from sentence_transformers import SentenceTransformer, util
# Sentence encoder used to embed user queries; the model id is resolved by
# sentence-transformers (local path or Hugging Face Hub name).
embedder = SentenceTransformer(model_name_or_path="ramdane/jurimodel")
def normalize_embeddings(embeddings):
    """Scale embedding vectors to unit L2 norm.

    Accepts a single embedding (1-D) or a batch of embeddings (2-D, one
    embedding per row). The input may be a ``torch.Tensor`` or anything
    ``torch.as_tensor`` can convert (e.g. a list or a NumPy array).

    Args:
        embeddings: 1-D vector or 2-D matrix of embeddings.

    Returns:
        A floating-point ``torch.Tensor`` with the same shape as the input.
        A 1-D input is normalized as a whole. For 2-D (or higher) input the
        norm is taken over dim 1, so each row of a 2-D batch becomes unit
        length. All-zero vectors stay zero, because ``F.normalize`` clamps
        the norm with a small epsilon.
    """
    tensor = torch.as_tensor(embeddings)
    # F.normalize needs a floating-point (or complex) tensor; integer input
    # would otherwise raise inside the norm computation.
    if not (tensor.is_floating_point() or tensor.is_complex()):
        tensor = tensor.float()
    # A single vector is normalized along its only axis; batches along dim 1.
    dim = 0 if tensor.dim() == 1 else 1
    return torch.nn.functional.normalize(tensor, p=2, dim=dim)
# Coerce the loaded corpus embeddings to a torch tensor and rescale them to
# unit L2 length (presumably one row per corpus entry), so they live on the
# same scale as the normalized query embedding.
corpus_embeddings = normalize_embeddings(
    corpus_embeddings
    if isinstance(corpus_embeddings, torch.Tensor)
    else torch.tensor(corpus_embeddings)
)
def showr(queries, number):
    """Return the corpus entry ranked ``number`` for a semantic-search query.

    Args:
        queries: The search text. If a list of strings is passed, only the
            hits for the first query are used.
        number: 0-based rank of the hit to return (0 = best match).

    Returns:
        The ``corpus`` entry for that hit. Otherwise returns an Arabic
        "could not find a result" message when the rank is negative, when
        the rank exceeds the available hits, or when the hit's similarity
        score is not above 0.05.
    """
    not_found = "لم نتمكن من ايجاد النتيجة اما لعدم وجود الاجتهاد او لعدم كتابة جملة بحث مناسبة "
    # A negative rank would silently index from the end of the hit list.
    if number < 0:
        return not_found
    # Embed and normalize the query the same way as the corpus embeddings.
    query_embedding = normalize_embeddings(
        embedder.encode(queries, convert_to_tensor=True)
    )
    # Fetch at least 10 hits (as before), but always enough to reach the
    # requested rank; index 0 selects the hits for the first query.
    hits = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=max(10, number + 1)
    )[0]
    # The corpus may contain fewer entries than the requested rank.
    if number >= len(hits):
        return not_found
    hit = hits[number]
    if hit["score"] > 0.05:
        return corpus[hit["corpus_id"]]
    return not_found
import gradio as gr
def greet(search_for, number):
    """Gradio callback: look up ``search_for`` and return the hit at rank ``number``.

    Args:
        search_for: Query text from the search textbox.
        number: Rank from the ``gr.Number`` input. It may arrive as a float,
            or as ``None`` if the field is left empty (Gradio behaviour;
            confirm for the installed version).

    Returns:
        Whatever ``showr`` returns for the query and the integer rank.
    """
    # int(None) would raise TypeError; fall back to the best match (rank 0).
    rank = 0 if number is None else int(number)
    return showr(search_for, rank)
# Gradio UI. The inputs are a query textbox ("enter search words") and a rank
# selector ("rank"), both feeding `greet`. The selected entry is shown in a
# text area ("ijtihad" / case-law ruling).
iface = gr.Interface(
    fn=greet,
    inputs=[
        gr.Textbox(label="ادخل كلمات البحث"),
        # precision=0 makes Gradio deliver an integer rank; value=0 starts
        # the field at the best match instead of leaving it empty.
        gr.Number(label="الترتيب", value=0, precision=0),
    ],
    outputs=gr.TextArea(label="الاجتهاد"),
)
iface.launch()