# ErayNet-nirig — api/search.py
# Uploaded by haajidheere ("Add api/search.py", commit 56c8fdd, verified).
import csv
import os
from collections import Counter
from typing import List, Optional

import numpy as np
from fastapi import FastAPI, Query, HTTPException
from pydantic import BaseModel
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# FastAPI application exposing the search endpoints defined below.
app = FastAPI(title="ErayNet Search API")
# Cleaned dataset CSV, resolved relative to this file:
# <repo>/data/cleaned/abbreviations.csv, one level above the api/ directory.
DATA_PATH = os.path.join(os.path.dirname(__file__), "..", "data", "cleaned", "abbreviations.csv")
class Entry(BaseModel):
    """One row of the abbreviations CSV (see load_data for the column mapping)."""
    id: int                # unique numeric identifier from the CSV
    raw_text: str          # original unprocessed source text
    abbreviation: str      # the abbreviation itself
    somali: str            # Somali-language field
    italian: str           # Italian-language field
    english: str           # English-language field
    domain: str            # subject domain; may be empty
    pos: str               # part-of-speech tag; may be empty
    quality_score: float   # numeric quality score for the row
    review_status: str     # human-review state; may be empty
    notes: str             # free-form notes
class SemanticEntry(Entry):
    """An Entry augmented with a semantic-search relevance score.

    Previously this model re-declared all eleven Entry fields by hand;
    inheriting from Entry keeps the two models from drifting apart while
    exposing exactly the same field set (all Entry fields plus `score`).
    """
    score: float  # cosine-similarity score; rounded to 2 decimals by the caller
class SemanticSearchResult(BaseModel):
    """Response payload for /search/semantic."""
    entries: List[SemanticEntry]  # scored matches, highest similarity first
    total: int                    # len(entries)
    query_type: str               # always "semantic" for this endpoint
class UnifiedSearchResult(BaseModel):
    """Response payload for /search; records which strategy produced the hits."""
    query: str            # the query string as received
    matched_by: str       # "exact", "partial", or "semantic"
    entries: List[Entry]  # matching entries
    total: int            # len(entries)
class SearchResult(BaseModel):
    """Response payload for /search/exact and /search/partial."""
    entries: List[Entry]  # matching entries
    total: int            # len(entries)
    query_type: str       # "exact" or "partial", set by the endpoint
def load_data():
    """Read the cleaned abbreviations CSV at DATA_PATH into Entry models.

    Numeric columns (`id`, `quality_score`) are converted explicitly;
    every other column is passed through as a string.
    """
    with open(DATA_PATH, 'r', encoding='utf-8') as fh:
        return [
            Entry(
                id=int(record['id']),
                raw_text=record['raw_text'],
                abbreviation=record['abbreviation'],
                somali=record['somali'],
                italian=record['italian'],
                english=record['english'],
                domain=record['domain'],
                pos=record['pos'],
                quality_score=float(record['quality_score']),
                review_status=record['review_status'],
                notes=record['notes'],
            )
            for record in csv.DictReader(fh)
        ]
def build_search_index(entries):
    """Fit a character n-gram TF-IDF index over the searchable fields.

    Each entry becomes one document: its abbreviation and the three
    language fields plus the raw text, space-joined. Returns the fitted
    vectorizer and the document-term matrix.
    """
    corpus = [
        f"{item.abbreviation} {item.somali} {item.italian} {item.english} {item.raw_text}"
        for item in entries
    ]
    vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(2, 4))
    return vec, vec.fit_transform(corpus)
# Load the dataset and build the TF-IDF index once at import time so every
# request reuses the same in-memory structures.
entries = load_data()
vectorizer, tfidf_matrix = build_search_index(entries)
@app.get("/search/exact", response_model=SearchResult)
def exact_match(
q: str = Query(..., description="Query string"),
domain: Optional[str] = Query(None, description="Filter by domain"),
pos: Optional[str] = Query(None, description="Filter by part of speech"),
review_status: Optional[str] = Query(None, description="Filter by review status")
):
q = q.lower().strip()
results = [
e for e in entries
if (q == e.abbreviation.lower() or q == e.somali.lower() or q == e.italian.lower() or q == e.english.lower())
and (domain is None or e.domain.lower() == domain.lower())
and (pos is None or e.pos.lower() == pos.lower())
and (review_status is None or e.review_status.lower() == review_status.lower())
]
return SearchResult(entries=results, total=len(results), query_type="exact")
@app.get("/search/partial", response_model=SearchResult)
def partial_match(
q: str = Query(..., description="Query string"),
domain: Optional[str] = Query(None, description="Filter by domain"),
pos: Optional[str] = Query(None, description="Filter by part of speech"),
review_status: Optional[str] = Query(None, description="Filter by review status")
):
q = q.lower().strip()
results = [
e for e in entries
if (q in e.abbreviation.lower() or q in e.somali.lower() or q in e.italian.lower() or q in e.english.lower())
and (domain is None or e.domain.lower() == domain.lower())
and (pos is None or e.pos.lower() == pos.lower())
and (review_status is None or e.review_status.lower() == review_status.lower())
]
return SearchResult(entries=results, total=len(results), query_type="partial")
@app.get("/search/semantic", response_model=SemanticSearchResult)
def semantic_search(
q: str = Query(..., description="Query string"),
top_k: int = Query(5, ge=1, le=20),
domain: Optional[str] = Query(None, description="Filter by domain"),
pos: Optional[str] = Query(None, description="Filter by part of speech"),
review_status: Optional[str] = Query(None, description="Filter by review status")
):
query_vec = vectorizer.transform([q])
similarities = cosine_similarity(query_vec, tfidf_matrix).flatten()
filtered_indices = []
for i, e in enumerate(entries):
if similarities[i] > 0:
if (domain is None or e.domain.lower() == domain.lower()) and \
(pos is None or e.pos.lower() == pos.lower()) and \
(review_status is None or e.review_status.lower() == review_status.lower()):
filtered_indices.append(i)
filtered_indices.sort(key=lambda i: similarities[i], reverse=True)
top_indices = filtered_indices[:top_k]
results = [
SemanticEntry(
id=entries[i].id,
raw_text=entries[i].raw_text,
abbreviation=entries[i].abbreviation,
somali=entries[i].somali,
italian=entries[i].italian,
english=entries[i].english,
domain=entries[i].domain,
pos=entries[i].pos,
quality_score=entries[i].quality_score,
review_status=entries[i].review_status,
notes=entries[i].notes,
score=round(float(similarities[i]), 2)
)
for i in top_indices
]
return SemanticSearchResult(entries=results, total=len(results), query_type="semantic")
@app.get("/search", response_model=UnifiedSearchResult)
def unified_search(
q: str = Query(..., description="Query string"),
domain: Optional[str] = Query(None, description="Filter by domain"),
pos: Optional[str] = Query(None, description="Filter by part of speech"),
review_status: Optional[str] = Query(None, description="Filter by review status")
):
q_lower = q.lower().strip()
def matches_filters(e):
return (domain is None or e.domain.lower() == domain.lower()) and \
(pos is None or e.pos.lower() == pos.lower()) and \
(review_status is None or e.review_status.lower() == review_status.lower())
exact_results = [
e for e in entries
if (q_lower == e.abbreviation.lower() or q_lower == e.somali.lower() or q_lower == e.italian.lower() or q_lower == e.english.lower())
and matches_filters(e)
]
if exact_results:
return UnifiedSearchResult(query=q, matched_by="exact", entries=exact_results, total=len(exact_results))
partial_results = [
e for e in entries
if (q_lower in e.abbreviation.lower() or q_lower in e.somali.lower() or q_lower in e.italian.lower() or q_lower in e.english.lower())
and matches_filters(e)
]
if partial_results:
return UnifiedSearchResult(query=q, matched_by="partial", entries=partial_results, total=len(partial_results))
query_vec = vectorizer.transform([q])
similarities = cosine_similarity(query_vec, tfidf_matrix).flatten()
filtered_indices = [
i for i in range(len(entries))
if similarities[i] > 0 and matches_filters(entries[i])
]
filtered_indices.sort(key=lambda i: similarities[i], reverse=True)
top_indices = filtered_indices[:5]
semantic_results = [entries[i] for i in top_indices]
return UnifiedSearchResult(query=q, matched_by="semantic", entries=semantic_results, total=len(semantic_results))
@app.get("/entries", response_model=List[Entry])
def list_entries(skip: int = 0, limit: int = 100):
return entries[skip:skip+limit]
@app.get("/entries/{entry_id}", response_model=Entry)
def get_entry(entry_id: int):
for e in entries:
if e.id == entry_id:
return e
raise HTTPException(status_code=404, detail="Entry not found")
@app.get("/domains")
def list_domains():
domains = sorted(set(e.domain for e in entries if e.domain))
return {"domains": domains}
@app.get("/pos-tags")
def list_pos_tags():
pos_tags = sorted(set(e.pos for e in entries if e.pos))
return {"pos_tags": pos_tags}
@app.get("/stats")
def get_stats():
total = len(entries)
domains = {}
pos_tags = {}
review_statuses = {}
for e in entries:
if e.domain:
domains[e.domain] = domains.get(e.domain, 0) + 1
if e.pos:
pos_tags[e.pos] = pos_tags.get(e.pos, 0) + 1
if e.review_status:
review_statuses[e.review_status] = review_statuses.get(e.review_status, 0) + 1
return {
"total_entries": total,
"domains": dict(sorted(domains.items(), key=lambda x: -x[1])),
"pos_tags": dict(sorted(pos_tags.items(), key=lambda x: -x[1])),
"review_statuses": dict(sorted(review_statuses.items(), key=lambda x: -x[1]))
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)