# academic-recommender-api / api_server.py
# siham47's picture
# initial deployment
# ee87c07
"""
Academic Recommendation API Server
Exposes the recommendation engine as a REST API for n8n integration.
Author: Siham Zaiad Al Kousa (U24200503)
Course: 1501531 Machine Learning
Date: December 2025
"""
import json
import re
import time
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sklearn.metrics.pairwise import cosine_similarity

# SPECTER2 imports
from transformers import AutoTokenizer
from adapters import AutoAdapterModel
# ============================================================================
# CONFIGURATION
# ============================================================================
# Server-wide settings: data file locations, SPECTER2 model/adapter
# identifiers, the torch device chosen at import time, and top-k limits.
CONFIG = dict(
    corpus_path='data_final/processed/corpus_with_embeddings.json',
    embeddings_path='data_final/processed/embeddings.npy',
    specter2_model='allenai/specter2_base',
    specter2_adapter='allenai/specter2_adhoc_query',
    # Use the GPU when torch reports one, otherwise fall back to the CPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    default_top_k=10,
    max_top_k=50,
)
# ============================================================================
# PYDANTIC MODELS (Request/Response schemas)
# ============================================================================
class RecommendationRequest(BaseModel):
    """Request schema for recommendations."""

    # Free-text search query; the only required field.
    query: str = Field(..., description="Search query")
    # Number of results to return. The default and the upper bound are read
    # from CONFIG so the schema cannot drift from the server settings.
    top_k: int = Field(
        default=CONFIG['default_top_k'],
        ge=1,
        le=CONFIG['max_top_k'],
        description="Number of recommendations",
    )
    # Optional filters; None means "do not filter on this attribute".
    filter_type: Optional[str] = Field(default=None, description="Filter by 'paper' or 'video'")
    year_min: Optional[int] = Field(default=None, description="Minimum publication year")
    year_max: Optional[int] = Field(default=None, description="Maximum publication year")
    category: Optional[str] = Field(default=None, description="Filter by arXiv category")
    min_citations: Optional[int] = Field(default=None, description="Minimum citation count")
class PaperMetadata(BaseModel):
    """Metadata for a single paper."""

    paper_id: str
    title: str
    authors: List[str]
    abstract: str
    published: str
    citations: int
    category: str
    # Explicit None defaults: under Pydantic v2 an Optional[...] annotation
    # without a default is still a required field, so spell the default out
    # to keep these two fields omittable on every Pydantic version.
    arxiv_id: Optional[str] = None
    url: Optional[str] = None
class RecommendationItem(BaseModel):
    """Single recommendation with scores."""

    # Item identifier and item type as produced by the engine.
    id: str
    type: str
    title: str
    abstract: str
    # Free-form descriptive fields (authors, published, category, url, ...).
    metadata: Dict[str, Any]
    # Named score components, including the final blended score.
    scores: Dict[str, float]
    # 1-based position in the returned result list.
    rank: int
class RecommendationResponse(BaseModel):
    """Response schema for recommendations."""

    # The query string echoed back from the request.
    query: str
    # Number of entries in `recommendations`.
    total_results: int
    recommendations: List[RecommendationItem]
    # Time spent producing the recommendations, in milliseconds.
    execution_time_ms: float
# ============================================================================
# SPECTER2 ENCODER
# ============================================================================
class SPECTER2Encoder:
    """Query encoder built from a SPECTER2 base model and an adapter."""

    def __init__(self, model_name: str, adapter_name: str, device: str):
        """Load tokenizer, base model and adapter, then move to `device`.

        Args:
            model_name: Identifier passed to both ``from_pretrained`` calls.
            adapter_name: Adapter identifier loaded and activated on the model.
            device: Torch device string (for example ``'cpu'`` or ``'cuda'``).
        """
        self.device = torch.device(device)

        print(f"Loading SPECTER2 model: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoAdapterModel.from_pretrained(model_name)

        print(f"Loading adapter: {adapter_name}")
        self.model.load_adapter(adapter_name, source='hf', set_active=True)

        # Place the model on the target device and switch to inference mode.
        self.model.to(self.device)
        self.model.eval()
        print(f"βœ“ SPECTER2 ready on {self.device}")

    def encode_query(self, query: str) -> np.ndarray:
        """Return the embedding of `query` as a 1-D numpy array.

        The embedding is the hidden state at sequence position 0 of the
        first (and only) batch entry.
        """
        tokenizer_options = {
            'padding': True,
            'truncation': True,
            'max_length': 512,
            'return_tensors': 'pt',
        }
        batch = self.tokenizer(query, **tokenizer_options).to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            hidden_states = self.model(**batch).last_hidden_state

        first_token_states = hidden_states[:, 0, :]
        return first_token_states.cpu().numpy()[0]
# ============================================================================
# RECOMMENDATION ENGINE (Simplified)
# ============================================================================
class RecommendationEngine:
    """Simplified recommendation engine for API.

    Keeps the corpus items and their embedding matrix in memory and ranks
    items against a query with a fixed blend of 60% cosine similarity,
    20% capped citation impact and 20% recency.
    """

    # The first run of four digits in a 'published' string is read as the year.
    _YEAR_PATTERN = re.compile(r'\d{4}')
    # Abstracts are truncated to this many characters in results.
    _ABSTRACT_MAX_CHARS = 500
    # Citation counts at or above this value map to an impact score of 1.0.
    _IMPACT_CAP = 500.0
    # NOTE: item age is not derived from the data yet; every item is scored
    # with this fixed age, so the recency term is the same for all items.
    _DEFAULT_AGE_MONTHS = 30.0
    # Decay constant of exp(-age / 24): the score falls to 1/e after 24
    # months (this is a time constant, not a half-life).
    _RECENCY_DECAY_MONTHS = 24.0

    def __init__(self, corpus_path: str, embeddings_path: str, encoder: 'SPECTER2Encoder'):
        """Load the corpus JSON and the embedding matrix.

        Args:
            corpus_path: JSON file with an 'items' list and an optional
                'metadata' mapping.
            embeddings_path: ``.npy`` file loaded with ``np.load``.
            encoder: Object whose ``encode_query`` embeds a query string.
        """
        print(f"Loading corpus from: {corpus_path}")
        with open(corpus_path, 'r', encoding='utf-8') as f:
            corpus_data = json.load(f)

        # Items live under the top-level 'items' key.
        self.corpus = corpus_data.get('items', [])
        if not self.corpus:
            print("⚠️ Warning: No items found in corpus!")

        print(f"Loading embeddings from: {embeddings_path}")
        self.embeddings = np.load(embeddings_path)

        # recommend() pairs corpus item i with similarity i, so a row-count
        # mismatch means wrong scores or an IndexError at query time.
        embedding_rows = self.embeddings.shape[0] if self.embeddings.ndim > 0 else 0
        if embedding_rows != len(self.corpus):
            print(
                f"⚠️ Warning: corpus has {len(self.corpus)} items but "
                f"embeddings has {embedding_rows} rows!"
            )

        self.corpus_metadata = corpus_data.get('metadata', {})
        self.encoder = encoder

        print(f"βœ“ Loaded {len(self.corpus)} items")
        print(f"βœ“ Embeddings shape: {self.embeddings.shape}")
        print(f"βœ“ Corpus metadata keys: {list(self.corpus_metadata.keys())}")

    @classmethod
    def _year_in_range(cls, published: Any, year_min: Optional[int], year_max: Optional[int]) -> bool:
        """Return False only when a parsed year falls outside the bounds.

        Items whose 'published' value is not a string, or contains no
        four-digit run, are kept (best-effort filtering).
        """
        if year_min is None and year_max is None:
            return True
        if not isinstance(published, str):
            return True
        match = cls._YEAR_PATTERN.search(published)
        if match is None:
            return True
        year = int(match.group())
        if year_min is not None and year < year_min:
            return False
        if year_max is not None and year > year_max:
            return False
        return True

    @staticmethod
    def _raw_category(metadata: Dict) -> Any:
        """Return 'primary_category', else 'category', else an empty string."""
        return metadata.get('primary_category') or metadata.get('category') or ''

    @staticmethod
    def _citation_count(metadata: Dict) -> Any:
        """Return 'citationCount' (or 'citations'); 0 when not numeric."""
        count = metadata.get('citationCount', 0) or metadata.get('citations', 0)
        return count if isinstance(count, (int, float)) else 0

    @classmethod
    def _abstract_snippet(cls, item: Dict) -> str:
        """Return the truncated 'abstract', falling back to 'abstract_cleaned'."""
        text = item.get('abstract') or item.get('abstract_cleaned') or ''
        if not isinstance(text, str):
            # A null or non-string abstract must not break slicing.
            text = str(text)
        return text[:cls._ABSTRACT_MAX_CHARS]

    def _build_result(self, index: int, rank: int, scores: Dict[str, float]) -> Dict:
        """Assemble the response dict for corpus item `index`."""
        item = self.corpus[index]
        metadata = item.get('metadata') or {}
        return {
            'id': item.get('id', f'item_{index}'),
            'type': item.get('type', 'paper'),
            'title': item.get('title', 'Untitled'),
            'abstract': self._abstract_snippet(item),
            'metadata': {
                'authors': metadata.get('authors', []),
                'published': metadata.get('published', ''),
                'citationCount': scores['impact'],
                'primary_category': self._raw_category(metadata),
                'arxiv_id': item.get('arxiv_id', ''),
                'url': metadata.get('url', '') or metadata.get('pdf_url', ''),
            },
            'scores': scores,
            'rank': rank,
        }

    def recommend(self,
                  query: str,
                  top_k: int = 10,
                  filter_type: Optional[str] = None,
                  year_min: Optional[int] = None,
                  year_max: Optional[int] = None,
                  category: Optional[str] = None,
                  min_citations: Optional[int] = None) -> List[Dict]:
        """Generate recommendations with optional filters.

        Args:
            query: Free-text query, embedded with ``self.encoder``.
            top_k: Maximum number of results; values <= 0 yield no results.
            filter_type: Keep only items whose 'type' equals this value
                (items without a 'type' count as 'paper').
            year_min: Drop items whose parsed publication year is lower.
            year_max: Drop items whose parsed publication year is higher.
            category: Case-insensitive substring matched against the
                item's category.
            min_citations: Drop items with fewer citations.

        Returns:
            Up to `top_k` dicts with keys id, type, title, abstract,
            metadata, scores and rank, ordered by descending final score.
        """
        query_embedding = self.encoder.encode_query(query)
        similarities = cosine_similarity(
            query_embedding.reshape(1, -1),
            self.embeddings
        )[0]

        # Loop-invariant values, computed once.
        recency = float(np.exp(-self._DEFAULT_AGE_MONTHS / self._RECENCY_DECAY_MONTHS))
        category_key = category.lower() if category else None

        # Lightweight (final_score, index, scores) records; full result
        # dicts are built only for the items that make the top-k cut.
        candidates = []
        for i, item in enumerate(self.corpus):
            item_type = item.get('type', 'paper')  # Default to paper
            if filter_type and item_type != filter_type:
                continue

            metadata = item.get('metadata') or {}

            if not self._year_in_range(metadata.get('published', ''), year_min, year_max):
                continue

            if category_key and category_key not in str(self._raw_category(metadata)).lower():
                continue

            impact = self._citation_count(metadata)
            if min_citations and impact < min_citations:
                continue

            similarity = float(similarities[i])
            impact_normalized = min(impact / self._IMPACT_CAP, 1.0)
            # Weighted final score: 60% similarity, 20% impact, 20% recency.
            final_score = 0.6 * similarity + 0.2 * impact_normalized + 0.2 * recency

            candidates.append((final_score, i, {
                'similarity': similarity,
                'impact': impact,
                'impact_normalized': impact_normalized,
                'recency': recency,
                'final_score': final_score,
            }))

        # Stable sort: ties keep corpus order.
        candidates.sort(key=lambda candidate: candidate[0], reverse=True)

        return [
            self._build_result(index, rank, scores)
            for rank, (_, index, scores) in enumerate(candidates[:max(top_k, 0)], 1)
        ]
# ============================================================================
# FASTAPI APPLICATION
# ============================================================================
app = FastAPI(
    title="Academic Recommendation API",
    description="LLM-Powered recommendation system for academic papers and videos",
    version="1.0.0",
)

# Enable CORS for any origin. Credentials stay disabled: a wildcard origin
# must not be combined with allow_credentials=True, and this API uses no
# cookies or other browser credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global engine instance; stays None until the startup handler has loaded it.
engine: Optional[RecommendationEngine] = None
@app.on_event("startup")
async def startup_event():
    """Build the encoder and the engine once, when the server starts.

    Any exception raised while loading is printed and re-raised.
    """
    global engine

    banner = "=" * 70
    print(banner)
    print("STARTING RECOMMENDATION API SERVER")
    print(banner)

    try:
        # The encoder comes first because the engine takes it as an argument.
        query_encoder = SPECTER2Encoder(
            model_name=CONFIG['specter2_model'],
            adapter_name=CONFIG['specter2_adapter'],
            device=CONFIG['device'],
        )
        engine = RecommendationEngine(
            corpus_path=CONFIG['corpus_path'],
            embeddings_path=CONFIG['embeddings_path'],
            encoder=query_encoder,
        )
    except Exception as e:
        print(f"\n❌ ERROR during startup: {str(e)}")
        raise
    else:
        print("\nβœ… API Server Ready!")
        print(f"Device: {CONFIG['device']}")
        print(f"Corpus: {len(engine.corpus)} items")
        print(banner)
@app.get("/")
async def root():
    """Health check endpoint.

    Reports the service name, status, version and the number of corpus
    items (0 while the engine is not loaded).
    """
    corpus_size = len(engine.corpus) if engine else 0
    return dict(
        service="Academic Recommendation API",
        status="running",
        version="1.0.0",
        corpus_size=corpus_size,
    )
@app.get("/health")
async def health():
    """Detailed health check.

    Reports whether the engine is loaded, the configured device and the
    number of corpus items (0 while the engine is not loaded).
    """
    if engine:
        status, item_count = "healthy", len(engine.corpus)
    else:
        status, item_count = "initializing", 0

    return dict(
        status=status,
        device=CONFIG['device'],
        model_loaded=engine is not None,
        corpus_loaded=item_count,
    )
@app.post("/recommend", response_model=RecommendationResponse)
async def get_recommendations(request: RecommendationRequest):
    """
    Get paper/video recommendations for a query.

    **Parameters:**
    - query: Search query (required)
    - top_k: Number of results (1-50, default 10)
    - filter_type: Filter by 'paper' or 'video'
    - year_min: Minimum publication year
    - year_max: Maximum publication year
    - category: Filter by arXiv category
    - min_citations: Minimum citation count

    **Returns:**
    - Ranked list of recommendations with scores and metadata

    **Errors:**
    - 503 when the engine has not been initialized
    - 500 when generating or formatting the recommendations fails
    """
    if not engine:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    # perf_counter is monotonic, so the measured duration cannot go negative
    # if the system clock is adjusted while the request is running.
    started = time.perf_counter()
    try:
        results = engine.recommend(
            query=request.query,
            top_k=request.top_k,
            filter_type=request.filter_type,
            year_min=request.year_min,
            year_max=request.year_max,
            category=request.category,
            min_citations=request.min_citations,
        )
        execution_time = (time.perf_counter() - started) * 1000  # Convert to ms

        return RecommendationResponse(
            query=request.query,
            total_results=len(results),
            recommendations=[RecommendationItem(**item) for item in results],
            execution_time_ms=round(execution_time, 2),
        )
    except Exception as e:
        # Keep the original exception as the cause so the traceback survives.
        raise HTTPException(
            status_code=500,
            detail=f"Recommendation failed: {str(e)}",
        ) from e
@app.get("/stats")
async def get_stats():
    """Get corpus statistics.

    Returns the total item count, paper and video counts, the ten most
    frequent paper categories and the corpus-level metadata.

    Raises:
        HTTPException: 503 when the engine has not been initialized.
    """
    if not engine:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    paper_count = 0
    video_count = 0
    category_counts = Counter()

    # Single pass over the corpus instead of one pass per statistic.
    for item in engine.corpus:
        # Items without a 'type' count as papers, matching how
        # RecommendationEngine.recommend treats them.
        item_type = item.get('type', 'paper')
        if item_type == 'video':
            video_count += 1
        elif item_type == 'paper':
            paper_count += 1
            metadata = item.get('metadata') or {}
            # Empty or null categories are grouped under 'unknown'.
            cat = metadata.get('primary_category') or metadata.get('category') or 'unknown'
            category_counts[cat] += 1

    return {
        "total_items": len(engine.corpus),
        "papers": paper_count,
        "videos": video_count,
        "top_categories": [
            {"category": cat, "count": count}
            for cat, count in category_counts.most_common(10)
        ],
        "corpus_metadata": engine.corpus_metadata,
    }
# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
    # Print the startup hints, then serve the app on all interfaces, port 8000.
    for hint in (
        "\nπŸš€ Starting API server...",
        "πŸ“ API docs will be available at: http://localhost:8000/docs",
        "πŸ”§ Health check: http://localhost:8000/health\n",
    ):
        print(hint)

    server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "info"}
    uvicorn.run(app, **server_options)