haajidheere commited on
Commit
56c8fdd
·
verified ·
1 Parent(s): 0e6e402

Add api/search.py

Browse files
Files changed (1) hide show
  1. api/search.py +251 -0
api/search.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Query, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List, Optional
4
+ import csv
5
+ import os
6
+ from sklearn.feature_extraction.text import TfidfVectorizer
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ import numpy as np
9
+
10
+ app = FastAPI(title="ErayNet Search API")
11
+
12
+ DATA_PATH = os.path.join(os.path.dirname(__file__), "..", "data", "cleaned", "abbreviations.csv")
13
+
14
+ class Entry(BaseModel):
15
+ id: int
16
+ raw_text: str
17
+ abbreviation: str
18
+ somali: str
19
+ italian: str
20
+ english: str
21
+ domain: str
22
+ pos: str
23
+ quality_score: float
24
+ review_status: str
25
+ notes: str
26
+
27
+ class SemanticEntry(BaseModel):
28
+ id: int
29
+ raw_text: str
30
+ abbreviation: str
31
+ somali: str
32
+ italian: str
33
+ english: str
34
+ domain: str
35
+ pos: str
36
+ quality_score: float
37
+ review_status: str
38
+ notes: str
39
+ score: float
40
+
41
+ class SemanticSearchResult(BaseModel):
42
+ entries: List[SemanticEntry]
43
+ total: int
44
+ query_type: str
45
+
46
+ class UnifiedSearchResult(BaseModel):
47
+ query: str
48
+ matched_by: str
49
+ entries: List[Entry]
50
+ total: int
51
+
52
+ class SearchResult(BaseModel):
53
+ entries: List[Entry]
54
+ total: int
55
+ query_type: str
56
+
57
+ def load_data():
58
+ entries = []
59
+ with open(DATA_PATH, 'r', encoding='utf-8') as f:
60
+ reader = csv.DictReader(f)
61
+ for row in reader:
62
+ entries.append(Entry(
63
+ id=int(row['id']),
64
+ raw_text=row['raw_text'],
65
+ abbreviation=row['abbreviation'],
66
+ somali=row['somali'],
67
+ italian=row['italian'],
68
+ english=row['english'],
69
+ domain=row['domain'],
70
+ pos=row['pos'],
71
+ quality_score=float(row['quality_score']),
72
+ review_status=row['review_status'],
73
+ notes=row['notes']
74
+ ))
75
+ return entries
76
+
77
+ def build_search_index(entries):
78
+ documents = []
79
+ for e in entries:
80
+ doc = f"{e.abbreviation} {e.somali} {e.italian} {e.english} {e.raw_text}"
81
+ documents.append(doc)
82
+
83
+ vectorizer = TfidfVectorizer(analyzer='char_wb', ngram_range=(2, 4))
84
+ tfidf_matrix = vectorizer.fit_transform(documents)
85
+ return vectorizer, tfidf_matrix
86
+
87
+ entries = load_data()
88
+ vectorizer, tfidf_matrix = build_search_index(entries)
89
+
90
+ @app.get("/search/exact", response_model=SearchResult)
91
+ def exact_match(
92
+ q: str = Query(..., description="Query string"),
93
+ domain: Optional[str] = Query(None, description="Filter by domain"),
94
+ pos: Optional[str] = Query(None, description="Filter by part of speech"),
95
+ review_status: Optional[str] = Query(None, description="Filter by review status")
96
+ ):
97
+ q = q.lower().strip()
98
+ results = [
99
+ e for e in entries
100
+ if (q == e.abbreviation.lower() or q == e.somali.lower() or q == e.italian.lower() or q == e.english.lower())
101
+ and (domain is None or e.domain.lower() == domain.lower())
102
+ and (pos is None or e.pos.lower() == pos.lower())
103
+ and (review_status is None or e.review_status.lower() == review_status.lower())
104
+ ]
105
+ return SearchResult(entries=results, total=len(results), query_type="exact")
106
+
107
+ @app.get("/search/partial", response_model=SearchResult)
108
+ def partial_match(
109
+ q: str = Query(..., description="Query string"),
110
+ domain: Optional[str] = Query(None, description="Filter by domain"),
111
+ pos: Optional[str] = Query(None, description="Filter by part of speech"),
112
+ review_status: Optional[str] = Query(None, description="Filter by review status")
113
+ ):
114
+ q = q.lower().strip()
115
+ results = [
116
+ e for e in entries
117
+ if (q in e.abbreviation.lower() or q in e.somali.lower() or q in e.italian.lower() or q in e.english.lower())
118
+ and (domain is None or e.domain.lower() == domain.lower())
119
+ and (pos is None or e.pos.lower() == pos.lower())
120
+ and (review_status is None or e.review_status.lower() == review_status.lower())
121
+ ]
122
+ return SearchResult(entries=results, total=len(results), query_type="partial")
123
+
124
+ @app.get("/search/semantic", response_model=SemanticSearchResult)
125
+ def semantic_search(
126
+ q: str = Query(..., description="Query string"),
127
+ top_k: int = Query(5, ge=1, le=20),
128
+ domain: Optional[str] = Query(None, description="Filter by domain"),
129
+ pos: Optional[str] = Query(None, description="Filter by part of speech"),
130
+ review_status: Optional[str] = Query(None, description="Filter by review status")
131
+ ):
132
+ query_vec = vectorizer.transform([q])
133
+ similarities = cosine_similarity(query_vec, tfidf_matrix).flatten()
134
+
135
+ filtered_indices = []
136
+ for i, e in enumerate(entries):
137
+ if similarities[i] > 0:
138
+ if (domain is None or e.domain.lower() == domain.lower()) and \
139
+ (pos is None or e.pos.lower() == pos.lower()) and \
140
+ (review_status is None or e.review_status.lower() == review_status.lower()):
141
+ filtered_indices.append(i)
142
+
143
+ filtered_indices.sort(key=lambda i: similarities[i], reverse=True)
144
+ top_indices = filtered_indices[:top_k]
145
+
146
+ results = [
147
+ SemanticEntry(
148
+ id=entries[i].id,
149
+ raw_text=entries[i].raw_text,
150
+ abbreviation=entries[i].abbreviation,
151
+ somali=entries[i].somali,
152
+ italian=entries[i].italian,
153
+ english=entries[i].english,
154
+ domain=entries[i].domain,
155
+ pos=entries[i].pos,
156
+ quality_score=entries[i].quality_score,
157
+ review_status=entries[i].review_status,
158
+ notes=entries[i].notes,
159
+ score=round(float(similarities[i]), 2)
160
+ )
161
+ for i in top_indices
162
+ ]
163
+ return SemanticSearchResult(entries=results, total=len(results), query_type="semantic")
164
+
165
+ @app.get("/search", response_model=UnifiedSearchResult)
166
+ def unified_search(
167
+ q: str = Query(..., description="Query string"),
168
+ domain: Optional[str] = Query(None, description="Filter by domain"),
169
+ pos: Optional[str] = Query(None, description="Filter by part of speech"),
170
+ review_status: Optional[str] = Query(None, description="Filter by review status")
171
+ ):
172
+ q_lower = q.lower().strip()
173
+
174
+ def matches_filters(e):
175
+ return (domain is None or e.domain.lower() == domain.lower()) and \
176
+ (pos is None or e.pos.lower() == pos.lower()) and \
177
+ (review_status is None or e.review_status.lower() == review_status.lower())
178
+
179
+ exact_results = [
180
+ e for e in entries
181
+ if (q_lower == e.abbreviation.lower() or q_lower == e.somali.lower() or q_lower == e.italian.lower() or q_lower == e.english.lower())
182
+ and matches_filters(e)
183
+ ]
184
+ if exact_results:
185
+ return UnifiedSearchResult(query=q, matched_by="exact", entries=exact_results, total=len(exact_results))
186
+
187
+ partial_results = [
188
+ e for e in entries
189
+ if (q_lower in e.abbreviation.lower() or q_lower in e.somali.lower() or q_lower in e.italian.lower() or q_lower in e.english.lower())
190
+ and matches_filters(e)
191
+ ]
192
+ if partial_results:
193
+ return UnifiedSearchResult(query=q, matched_by="partial", entries=partial_results, total=len(partial_results))
194
+
195
+ query_vec = vectorizer.transform([q])
196
+ similarities = cosine_similarity(query_vec, tfidf_matrix).flatten()
197
+
198
+ filtered_indices = [
199
+ i for i in range(len(entries))
200
+ if similarities[i] > 0 and matches_filters(entries[i])
201
+ ]
202
+ filtered_indices.sort(key=lambda i: similarities[i], reverse=True)
203
+ top_indices = filtered_indices[:5]
204
+ semantic_results = [entries[i] for i in top_indices]
205
+
206
+ return UnifiedSearchResult(query=q, matched_by="semantic", entries=semantic_results, total=len(semantic_results))
207
+
208
+ @app.get("/entries", response_model=List[Entry])
209
+ def list_entries(skip: int = 0, limit: int = 100):
210
+ return entries[skip:skip+limit]
211
+
212
+ @app.get("/entries/{entry_id}", response_model=Entry)
213
+ def get_entry(entry_id: int):
214
+ for e in entries:
215
+ if e.id == entry_id:
216
+ return e
217
+ raise HTTPException(status_code=404, detail="Entry not found")
218
+
219
+ @app.get("/domains")
220
+ def list_domains():
221
+ domains = sorted(set(e.domain for e in entries if e.domain))
222
+ return {"domains": domains}
223
+
224
+ @app.get("/pos-tags")
225
+ def list_pos_tags():
226
+ pos_tags = sorted(set(e.pos for e in entries if e.pos))
227
+ return {"pos_tags": pos_tags}
228
+
229
+ @app.get("/stats")
230
+ def get_stats():
231
+ total = len(entries)
232
+ domains = {}
233
+ pos_tags = {}
234
+ review_statuses = {}
235
+ for e in entries:
236
+ if e.domain:
237
+ domains[e.domain] = domains.get(e.domain, 0) + 1
238
+ if e.pos:
239
+ pos_tags[e.pos] = pos_tags.get(e.pos, 0) + 1
240
+ if e.review_status:
241
+ review_statuses[e.review_status] = review_statuses.get(e.review_status, 0) + 1
242
+ return {
243
+ "total_entries": total,
244
+ "domains": dict(sorted(domains.items(), key=lambda x: -x[1])),
245
+ "pos_tags": dict(sorted(pos_tags.items(), key=lambda x: -x[1])),
246
+ "review_statuses": dict(sorted(review_statuses.items(), key=lambda x: -x[1]))
247
+ }
248
+
249
+ if __name__ == "__main__":
250
+ import uvicorn
251
+ uvicorn.run(app, host="0.0.0.0", port=8000)