F-allahmoradi committed on
Commit
64367bb
·
verified ·
1 Parent(s): eee453b

Upload core.py

Browse files
Files changed (1) hide show
  1. core.py +68 -0
core.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # core.py
2
+ from ilia3 import extract_text_from_pdf, find_jeld_param
3
+ import os
4
+ from PIL import Image
5
+ import torch
6
+ from transformers import CLIPProcessor, CLIPModel
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ import numpy as np
9
+ import json
10
+
11
# JSON file on disk that stores one embedding (as a list) per cover key.
JSON_PATH = "covers_embeddings.json"

# Hugging Face checkpoint shared by the CLIP model and its preprocessor.
MODEL_NAME = "openai/clip-vit-base-patch32"

# Both objects are created once, at import time, and reused by every call.
model = CLIPModel.from_pretrained(MODEL_NAME)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
15
+
16
def _load_db():
    """Load the embedding database stored at ``JSON_PATH``.

    Returns:
        The decoded JSON content of ``JSON_PATH``, or an empty dict when
        that file does not exist.
    """
    if not os.path.exists(JSON_PATH):
        return {}
    # The context manager closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open(JSON_PATH, encoding="utf-8") as fh:
        return json.load(fh)
18
+
19
def _save_db(db):
    """Serialize *db* as JSON into ``JSON_PATH``, replacing its content.

    Args:
        db: JSON-serializable object to persist (the callers in this file
            pass a dict of key -> embedding list).
    """
    # Closing the file via ``with`` guarantees the data is flushed to disk
    # before the function returns.
    with open(JSON_PATH, "w", encoding="utf-8") as fh:
        json.dump(db, fh)
21
+
22
def _get_embedding(pil_image):
    """Return the L2-normalized CLIP image embedding of *pil_image*.

    The image is preprocessed by the module-level ``processor`` and encoded
    by the module-level ``model`` with gradient tracking disabled.

    Args:
        pil_image: Image passed to the processor as ``images=``.

    Returns:
        A NumPy array with the feature vector scaled to unit L2 norm along
        its last axis, with size-1 dimensions squeezed out.
    """
    batch = processor(images=pil_image, return_tensors="pt")
    with torch.no_grad():
        features = model.get_image_features(**batch)
        # Divide by the per-vector L2 norm so the result has unit length.
        scale = features.norm(p=2, dim=-1, keepdim=True)
        unit_features = features / scale
    as_numpy = unit_features.cpu().numpy()
    return as_numpy.squeeze()
28
+
29
def analyze_or_save(pdf_path, pil_image, custom_name=None, threshold=0.90):
    """Report whether an image duplicates a stored one; store it if not.

    The embedding of *pil_image* is compared by cosine similarity with every
    embedding in the JSON database. When the best match is above *threshold*
    the image is reported as a duplicate and nothing is written. Otherwise
    the embedding is saved under a key built from *custom_name* (or the PDF
    base name) plus, when truthy, the value returned by ``find_jeld_param``.

    Args:
        pdf_path: Path of the PDF. Its base name without extension is the
            default key, and text extracted from it is passed to
            ``find_jeld_param``.
        pil_image: Image handed to ``_get_embedding``.
        custom_name: Optional key to use instead of the PDF base name.
            ``None``, empty, or whitespace-only values fall back to the
            base name.
        threshold: Cosine-similarity value (not a percentage) that the best
            match must exceed to count as a duplicate.

    Returns:
        dict: For a duplicate, ``{"status": "duplicate", "similarity": ...,
        "similar_path": <key of best match>}``. Otherwise
        ``{"status": "new", "similarity": ..., "saved_path": <new key>}``.
        ``similarity`` is the best cosine similarity multiplied by 100, or
        ``0.0`` when the database was empty.
    """
    base_name = os.path.splitext(os.path.basename(pdf_path))[0]
    # Strip first, then fall back: a whitespace-only custom_name would
    # otherwise yield an empty key.
    key = (custom_name or "").strip() or base_name

    # Extract the text of pages 2 to 5 (range as interpreted by
    # extract_text_from_pdf).
    text = extract_text_from_pdf(pdf_path, pages=(2, 5))
    jeld_param = find_jeld_param(text)
    if jeld_param:
        key += f"_{jeld_param}"

    db = _load_db()
    new_emb = _get_embedding(pil_image)

    # With an empty database there is nothing to compare against, so the
    # reported similarity stays at 0.0.
    best_sim = 0.0
    if db:
        keys = list(db)
        embeddings = np.array(list(db.values()))
        sims = cosine_similarity(new_emb.reshape(1, -1), embeddings)[0]
        best_idx = int(sims.argmax())
        # Convert the NumPy scalar to a plain float so the result dict is
        # JSON-serializable.
        best_sim = float(sims[best_idx])

        # Honor the caller-supplied threshold instead of a fixed 0.90.
        if best_sim > threshold:
            return {
                "status": "duplicate",
                "similarity": best_sim * 100,
                "similar_path": keys[best_idx],
            }

    # NOTE(review): an existing entry with the same key is overwritten here.
    db[key] = new_emb.tolist()
    _save_db(db)
    return {
        "status": "new",
        "similarity": best_sim * 100,
        "saved_path": key,
    }