Spaces:
No application file
No application file
| import os | |
| import warnings | |
| import logging | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
# Quiet TensorFlow before the ML libraries further down are imported.
# NOTE(review): presumably this must run before TensorFlow is loaded
# (indirectly, e.g. via sentence_transformers) to take effect — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow logging
logging.getLogger('tensorflow').setLevel(logging.ERROR)
# Note: this silences *every* UserWarning in the process, not only
# TensorFlow's.
warnings.simplefilter('ignore', UserWarning)
| from pydantic import BaseModel | |
| import pickle | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from backend.models.schemas import RecommendationRequest, ProductRecommendation, RecommendationResponse | |
# The FastAPI application object, with its descriptive metadata.
app = FastAPI(
    version="1.0.0",
    title="Product Recommendation API",
    description="API for getting product recommendations based on user queries",
)

# Fully permissive CORS: any origin, any method, any header.
# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI CORS docs (credentialed requests need explicit
# origin/method/header lists) — confirm this is intended before production.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Load the pre-computed product embeddings and the query encoder at import
# time; the request handlers read the resulting module-level `data` and
# `model` globals.
models_dir = os.path.join(os.path.dirname(__file__), "models")
embeddings_path = os.path.join(models_dir, "product_embeddings.pkl")
# Ensure models directory exists
os.makedirs(models_dir, exist_ok=True)
try:
    if not os.path.exists(embeddings_path):
        # Include the resolved path so a misconfigured deployment shows
        # exactly which location was searched.
        raise FileNotFoundError(
            f"Product embeddings file not found: {embeddings_path}"
        )
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load an embeddings file produced by a trusted build step.
    with open(embeddings_path, "rb") as f:
        data = pickle.load(f)
    # Load the sentence-transformer used to embed incoming queries.
    model = SentenceTransformer('all-MiniLM-L6-v2')
except Exception as e:
    # Report through logging (stderr by default) rather than stdout, then
    # re-raise so importing this module fails instead of leaving `data` /
    # `model` undefined.
    logging.getLogger(__name__).error("Error initializing server: %s", e)
    raise
def read_root():
    """Return a status payload for the service.

    The payload carries a fixed message, an ``"active"`` status, and the
    number of loaded products, taken from ``data["product_ids"]`` when that
    key is present and 0 otherwise.

    NOTE(review): no route decorator appears on this function in this
    listing, so as written it is not registered on ``app``. Presumably it was
    meant to be ``@app.get("/")`` — confirm.
    """
    total_products = 0
    if "product_ids" in data:
        total_products = len(data["product_ids"])

    return {
        "message": "Product Recommendation API",
        "status": "active",
        "total_products": total_products,
    }
def _parse_categories(raw):
    """Return a product's categories, decoding them if stored as a string.

    A string value is parsed with ``ast.literal_eval``, which accepts only
    Python literals (e.g. ``"['Shoes', 'Running']"``). Anything else raises
    ``ValueError``/``SyntaxError`` instead of being executed, as ``eval``
    would. Non-string values are returned unchanged.
    """
    import ast

    if isinstance(raw, str):
        return ast.literal_eval(raw)
    return raw


def recommend_products(request: RecommendationRequest):
    """Get product recommendations based on user search.

    Encodes ``request.query`` with the module-level ``model``, scores it
    against every row of ``data["embeddings"]`` by cosine similarity, and
    returns up to ``request.top_n`` products ordered by descending score.
    A non-positive ``top_n`` yields an empty list.

    Raises:
        HTTPException: status 500 wrapping any error raised while scoring
            or building the response.

    NOTE(review): no route decorator appears on this function in this
    listing, so it may not be registered on ``app`` — confirm.
    """
    try:
        top_n = request.top_n
        # Guard explicitly: slicing with ``[-0:]`` selects the whole array,
        # so top_n == 0 (or a negative value) used to return (nearly) every
        # product instead of none.
        if top_n <= 0:
            return RecommendationResponse(recommendations=[])

        query_embedding = model.encode(request.query).reshape(1, -1)
        # Row 0 holds the similarity of the single query to every product.
        scores = cosine_similarity(query_embedding, data["embeddings"])[0]
        # Indices by descending score; slicing caps the result at top_n
        # (fewer when the catalogue is smaller).
        top_indices = np.argsort(scores)[::-1][:top_n]

        recommendations = [
            ProductRecommendation(
                product_id=data["product_ids"][i],
                product_name=data["product_names"][i],
                description=data["descriptions"][i],
                brand=data["brands"][i],
                price=float(data["prices"][i]),
                categories=_parse_categories(data["categories"][i]),
                score=float(scores[i]),
            )
            for i in top_indices
        ]
        return RecommendationResponse(recommendations=recommendations)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    try:
        # The port may be overridden via the PORT environment variable;
        # it defaults to the previous hard-coded 8000.
        port = int(os.environ.get("PORT", "8000"))
        uvicorn.run(
            # Passed as an import string, which uvicorn requires for reload.
            "backend.server:app",
            host="0.0.0.0",
            port=port,
            reload=True,
            access_log=True,
        )
    except Exception as e:
        print(f"Error starting server: {e}")
        # Exit non-zero so supervisors/CI see the failure rather than a
        # silent clean exit.
        raise SystemExit(1) from e