# NOTE(review): removed stray "Spaces / Sleeping" text left over from the
# hosting page this file was scraped from — it was not part of the module.
"""
Academic Recommendation API Server

Exposes the recommendation engine as a REST API for n8n integration.

Author: Siham Zaiad Al Kousa (U24200503)
Course: 1501531 Machine Learning
Date: December 2025
"""
# Standard library
import json
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

# Third-party
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sklearn.metrics.pairwise import cosine_similarity

# SPECTER2 imports
from transformers import AutoTokenizer
from adapters import AutoAdapterModel
# ============================================================================
# CONFIGURATION
# ============================================================================
CONFIG = {
    # Corpus items + precomputed document embeddings from the offline pipeline.
    'corpus_path': 'data_final/processed/corpus_with_embeddings.json',
    'embeddings_path': 'data_final/processed/embeddings.npy',
    # SPECTER2 base model and the ad-hoc-query adapter used for queries.
    'specter2_model': 'allenai/specter2_base',
    'specter2_adapter': 'allenai/specter2_adhoc_query',
    # Prefer GPU when one is available.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    # Result-count default/limit (mirrors RecommendationRequest's 1-50 bounds).
    'default_top_k': 10,
    'max_top_k': 50,
}
# ============================================================================
# PYDANTIC MODELS (Request/Response schemas)
# ============================================================================
class RecommendationRequest(BaseModel):
    """Request schema for recommendations.

    The bounds on ``top_k`` (1-50, default 10) match CONFIG's
    ``default_top_k`` / ``max_top_k``.
    """
    query: str = Field(..., description="Search query")
    top_k: int = Field(default=10, ge=1, le=50, description="Number of recommendations")
    filter_type: Optional[str] = Field(default=None, description="Filter by 'paper' or 'video'")
    year_min: Optional[int] = Field(default=None, description="Minimum publication year")
    year_max: Optional[int] = Field(default=None, description="Maximum publication year")
    category: Optional[str] = Field(default=None, description="Filter by arXiv category")
    min_citations: Optional[int] = Field(default=None, description="Minimum citation count")
class PaperMetadata(BaseModel):
    """Metadata for a single paper.

    NOTE(review): not referenced by any endpoint visible in this file —
    RecommendationItem carries a free-form ``metadata`` dict instead.
    Kept because external callers may import it.
    """
    paper_id: str
    title: str
    authors: List[str]
    abstract: str
    published: str
    citations: int
    category: str
    arxiv_id: Optional[str]
    url: Optional[str]
class RecommendationItem(BaseModel):
    """A single ranked recommendation.

    ``scores`` holds the similarity/impact/recency components plus the
    blended ``final_score``; ``rank`` is 1-based within the response.
    """
    id: str
    type: str
    title: str
    abstract: str
    metadata: Dict[str, Any]
    scores: Dict[str, float]
    rank: int
class RecommendationResponse(BaseModel):
    """Envelope returned by the recommendation endpoint."""
    query: str
    total_results: int
    recommendations: List[RecommendationItem]
    execution_time_ms: float
# ============================================================================
# SPECTER2 ENCODER
# ============================================================================
class SPECTER2Encoder:
    """SPECTER2 encoder with the adhoc_query adapter for search queries.

    Loads the base SPECTER2 transformer plus the query adapter and exposes
    ``encode_query``, which returns the [CLS]-token embedding of a query.
    """

    def __init__(self, model_name: str, adapter_name: str, device: str):
        """Load tokenizer, base model and query adapter onto ``device``.

        Args:
            model_name: HF id of the SPECTER2 base model.
            adapter_name: HF id of the ad-hoc-query adapter.
            device: 'cuda' or 'cpu'.
        """
        self.device = torch.device(device)
        print(f"Loading SPECTER2 model: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoAdapterModel.from_pretrained(model_name)
        print(f"Loading adapter: {adapter_name}")
        self.model.load_adapter(adapter_name, source='hf', set_active=True)
        self.model.to(self.device)
        self.model.eval()  # inference only: disables dropout etc.
        # NOTE(review): original print contained mojibake ("β"); restored.
        print(f"✓ SPECTER2 ready on {self.device}")

    def encode_query(self, query: str) -> np.ndarray:
        """Encode a query string into a 1-D embedding vector.

        Uses the [CLS] token representation of the single input sequence,
        computed under ``torch.no_grad`` so no autograd graph is built.
        """
        inputs = self.tokenizer(
            query,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # [CLS] embedding of the (single) input, moved to CPU numpy.
        embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()[0]
        return embedding
# ============================================================================
# RECOMMENDATION ENGINE (Simplified)
# ============================================================================
class RecommendationEngine:
    """Simplified recommendation engine for the API.

    Pairs corpus items with precomputed embeddings (row i of the embedding
    matrix is assumed to correspond to ``self.corpus[i]`` — TODO confirm
    against the offline pipeline) and ranks them for a query by a weighted
    blend of semantic similarity, citation impact and recency.
    """

    def __init__(self, corpus_path: str, embeddings_path: str, encoder: SPECTER2Encoder):
        """Load corpus JSON and embedding matrix from disk.

        Args:
            corpus_path: JSON file with a top-level 'items' list and
                optional 'metadata' dict.
            embeddings_path: .npy matrix aligned row-for-row with 'items'.
            encoder: query encoder used at request time.
        """
        print(f"Loading corpus from: {corpus_path}")
        with open(corpus_path, 'r', encoding='utf-8') as f:
            corpus_data = json.load(f)
        # Items live under the nested 'items' key of the corpus file.
        self.corpus = corpus_data.get('items', [])
        if not self.corpus:
            # NOTE(review): mojibake in the original warning print restored.
            print("⚠️ Warning: No items found in corpus!")
        print(f"Loading embeddings from: {embeddings_path}")
        self.embeddings = np.load(embeddings_path)
        # Corpus-level metadata (build info etc.), exposed via /stats.
        self.corpus_metadata = corpus_data.get('metadata', {})
        self.encoder = encoder
        print(f"✓ Loaded {len(self.corpus)} items")
        print(f"✓ Embeddings shape: {self.embeddings.shape}")
        print(f"✓ Corpus metadata keys: {list(self.corpus_metadata.keys())}")

    def recommend(self,
                  query: str,
                  top_k: int = 10,
                  filter_type: Optional[str] = None,
                  year_min: Optional[int] = None,
                  year_max: Optional[int] = None,
                  category: Optional[str] = None,
                  min_citations: Optional[int] = None) -> List[Dict]:
        """Generate recommendations with optional filters.

        Filters are lenient: items whose metadata lacks a parseable
        publication year are NOT excluded by year_min/year_max.

        Returns:
            Up to ``top_k`` item dicts sorted by 'final_score' (descending),
            each carrying 'scores' and a 1-based 'rank'.
        """
        query_embedding = self.encoder.encode_query(query)
        # Cosine similarity of the query against every corpus embedding.
        similarities = cosine_similarity(
            query_embedding.reshape(1, -1),
            self.embeddings
        )[0]

        # Hoisted out of the per-item loop (original re-imported `re` and
        # re-scanned the pattern inside the loop).
        year_pattern = re.compile(r'\d{4}')

        scored_items = []
        for i, item in enumerate(self.corpus):
            item_type = item.get('type', 'paper')  # items default to papers
            if filter_type and item_type != filter_type:
                continue
            metadata = item.get('metadata', {})

            # --- Year filter: first 4-digit run in the published string ---
            if year_min or year_max:
                pub_date = metadata.get('published', '')
                if isinstance(pub_date, str):
                    year_match = year_pattern.search(pub_date)
                    if year_match:
                        # A \d{4} match always parses; no try/except needed.
                        year = int(year_match.group())
                        if year_min and year < year_min:
                            continue
                        if year_max and year > year_max:
                            continue

            # --- Category filter: case-insensitive substring match ---
            if category:
                item_cat = metadata.get('primary_category', '') or metadata.get('category', '')
                if not isinstance(item_cat, str):
                    item_cat = str(item_cat)
                if category.lower() not in item_cat.lower():
                    continue

            # --- Citation filter (either citation field may be present) ---
            if min_citations:
                citations = metadata.get('citationCount', 0) or metadata.get('citations', 0)
                if not isinstance(citations, (int, float)):
                    citations = 0
                if citations < min_citations:
                    continue

            # --- Scoring ---
            similarity = float(similarities[i])
            impact = metadata.get('citationCount', 0) or metadata.get('citations', 0)
            if not isinstance(impact, (int, float)):
                impact = 0
            # TODO(review): age is a fixed placeholder — parse 'fetched_at'
            # or 'published' before relying on the recency component.
            age_months = 30.0
            recency = np.exp(-age_months / 24.0)  # exponential decay, tau = 24 months
            impact_normalized = min(impact / 500.0, 1.0)  # saturate at 500 citations
            # Weighted blend: 60% similarity, 20% impact, 20% recency.
            final_score = 0.6 * similarity + 0.2 * impact_normalized + 0.2 * recency

            scored_items.append({
                'id': item.get('id', f'item_{i}'),
                'type': item_type,
                'title': item.get('title', 'Untitled'),
                # None-safe abstract fallback (the original sliced before
                # testing truthiness and raised TypeError on a None abstract).
                'abstract': (item.get('abstract') or item.get('abstract_cleaned') or '')[:500],
                'metadata': {
                    'authors': metadata.get('authors', []),
                    'published': metadata.get('published', ''),
                    'citationCount': impact,
                    'primary_category': metadata.get('primary_category', '') or metadata.get('category', ''),
                    'arxiv_id': item.get('arxiv_id', ''),
                    'url': metadata.get('url', '') or metadata.get('pdf_url', ''),
                },
                'scores': {
                    'similarity': similarity,
                    'impact': impact,
                    'impact_normalized': impact_normalized,
                    'recency': recency,
                    'final_score': final_score,
                },
            })

        # Rank by blended score, best first, and keep the top-k.
        scored_items.sort(key=lambda x: x['scores']['final_score'], reverse=True)
        results = scored_items[:top_k]
        for rank, item in enumerate(results, 1):
            item['rank'] = rank
        return results
# ============================================================================
# FASTAPI APPLICATION
# ============================================================================
app = FastAPI(
    title="Academic Recommendation API",
    description="LLM-Powered recommendation system for academic papers and videos",
    version="1.0.0",
)

# Allow any origin so the n8n workflow (or a browser client) can call us.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers under the CORS spec — consider pinning origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global engine instance, populated by the startup hook.
engine = None
# NOTE(review): the event decorator appears to have been stripped when this
# file was scraped; without it the engine is never loaded and every endpoint
# returns 503. Restored here.
@app.on_event("startup")
async def startup_event():
    """Load the SPECTER2 encoder and recommendation engine on startup."""
    global engine
    print("=" * 70)
    print("STARTING RECOMMENDATION API SERVER")
    print("=" * 70)
    try:
        encoder = SPECTER2Encoder(
            model_name=CONFIG['specter2_model'],
            adapter_name=CONFIG['specter2_adapter'],
            device=CONFIG['device'],
        )
        engine = RecommendationEngine(
            corpus_path=CONFIG['corpus_path'],
            embeddings_path=CONFIG['embeddings_path'],
            encoder=encoder,
        )
        # NOTE(review): mojibake in the original status prints restored.
        print("\n✓ API Server Ready!")
        print(f"Device: {CONFIG['device']}")
        print(f"Corpus: {len(engine.corpus)} items")
        print("=" * 70)
    except Exception as e:
        # Surface the failure and abort startup rather than serving a dead API.
        print(f"\n✗ ERROR during startup: {str(e)}")
        raise
# NOTE(review): route decorator restored (stripped in scrape); "/" inferred
# from the "Health check endpoint" role — confirm against the n8n workflow.
@app.get("/")
async def root():
    """Lightweight health check / service banner."""
    return {
        "service": "Academic Recommendation API",
        "status": "running",
        "version": "1.0.0",
        # 0 until the startup hook has populated the engine.
        "corpus_size": len(engine.corpus) if engine else 0,
    }
# NOTE(review): route decorator restored (stripped in scrape); the /health
# path is referenced by the startup banner in __main__.
@app.get("/health")
async def health():
    """Detailed health check: device, model and corpus load status."""
    return {
        "status": "healthy" if engine else "initializing",
        "device": CONFIG['device'],
        "model_loaded": engine is not None,
        "corpus_loaded": len(engine.corpus) if engine else 0,
    }
# NOTE(review): route decorator restored (stripped in scrape); the
# "/recommend" path is inferred — confirm against the n8n workflow.
@app.post("/recommend", response_model=RecommendationResponse)
async def get_recommendations(request: RecommendationRequest):
    """
    Get paper/video recommendations for a query.

    **Parameters:**
    - query: Search query (required)
    - top_k: Number of results (1-50, default 10)
    - filter_type: Filter by 'paper' or 'video'
    - year_min: Minimum publication year
    - year_max: Maximum publication year
    - category: Filter by arXiv category
    - min_citations: Minimum citation count

    **Returns:**
    - Ranked list of recommendations with scores and metadata
    """
    if not engine:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    try:
        # perf_counter is monotonic — unlike time.time it cannot go backwards
        # mid-request, so the reported duration is reliable.
        start_time = time.perf_counter()
        results = engine.recommend(
            query=request.query,
            top_k=request.top_k,
            filter_type=request.filter_type,
            year_min=request.year_min,
            year_max=request.year_max,
            category=request.category,
            min_citations=request.min_citations,
        )
        execution_time = (time.perf_counter() - start_time) * 1000  # ms
        return RecommendationResponse(
            query=request.query,
            total_results=len(results),
            recommendations=[RecommendationItem(**item) for item in results],
            execution_time_ms=round(execution_time, 2),
        )
    except Exception as e:
        # Blanket catch keeps the API contract (500 + message) for any
        # engine failure; the 503 above is raised before entering the try.
        raise HTTPException(status_code=500, detail=f"Recommendation failed: {str(e)}")
# NOTE(review): route decorator restored (stripped in scrape); "/stats"
# inferred from the function's purpose — confirm against the n8n workflow.
@app.get("/stats")
async def get_stats():
    """Get corpus statistics: item counts and top-10 paper categories."""
    if not engine:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    papers = [item for item in engine.corpus if item.get('type') == 'paper']
    videos = [item for item in engine.corpus if item.get('type') == 'video']
    # Tally papers per category (either of two possible metadata fields).
    categories = {}
    for paper in papers:
        metadata = paper.get('metadata', {})
        cat = metadata.get('primary_category', '') or metadata.get('category', 'unknown')
        categories[cat] = categories.get(cat, 0) + 1
    top_categories = sorted(categories.items(), key=lambda x: x[1], reverse=True)[:10]
    return {
        "total_items": len(engine.corpus),
        "papers": len(papers),
        "videos": len(videos),
        "top_categories": [{"category": cat, "count": count} for cat, count in top_categories],
        "corpus_metadata": engine.corpus_metadata,
    }
# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
    # NOTE(review): the original banner prints contained mojibake emoji
    # ("π", "π§"); replaced with plain-text labels.
    print("\nStarting API server...")
    print("API docs will be available at: http://localhost:8000/docs")
    print("Health check: http://localhost:8000/health\n")
    uvicorn.run(
        app,
        host="0.0.0.0",  # listen on all interfaces (containerized deployment)
        port=8000,
        log_level="info",
    )