# First_rec / app.py
# Hugging Face Space upload by Noor22Tak (commit 0ec3c01, "Update app.py").
import base64
import traceback
import faiss
from fastapi import FastAPI, HTTPException
import requests
from pydantic import BaseModel
import numpy as np
import pandas as pd
import os
# Initialize FastAPI app
app = FastAPI()

# Hugging Face Inference API credentials — read from the environment so the
# key is never committed to source control. May be None if unset.
HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")  # Load from environment variable

# Feature-extraction endpoint for the 384-dimensional MiniLM sentence encoder.
API_URL = "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/all-MiniLM-L6-v2"
HEADERS = {
    "Authorization": f"Bearer {HUGGINGFACE_API_KEY}",
    "Content-Type": "application/json; charset=UTF-8",
}

# Store embeddings globally (in-memory storage)
global_embedding = None

# Load the prebuilt FAISS index at import time. The original filename
# "news_index.faissFF" looks like a typo for "news_index.faiss" (the commented
# recreation code below writes conventional ".index"/".faiss" names), so try
# the original name first for backward compatibility, then fall back to the
# corrected one. If neither exists the second read raises, preserving the
# original fail-fast-at-startup behavior.
try:
    index = faiss.read_index("news_index.faissFF")
except Exception:
    index = faiss.read_index("news_index.faiss")
@app.get('/')
def home():
    """Root health-check endpoint: returns a static greeting."""
    greeting = {"Message": "Hello"}
    return greeting
# ReCreate the index file with 384 embedding -------------------------------------------------
# Define the correct dimension
# embedding_dim = 384
# # Create a new FAISS index with L2 distance
# new_index = faiss.IndexFlatL2(embedding_dim)
# # Extract only the first 384 dimensions from the old 4096D vectors
# stored_vectors = index.reconstruct_n(0, index.ntotal) # Get all stored vectors
# stored_vectors_384 = stored_vectors[:, :embedding_dim] # Keep only first 384D
# # Add them to the new FAISS index
# new_index.add(stored_vectors_384)
# faiss.write_index(new_index, "faiss_index_384D.index")
# -----------------------------------------------
# Request model for input validation
class EmbeddingRequest(BaseModel):
    """Request body for the embedding endpoints: a single text to embed."""
    text: str
# Function to get embedding from Hugging Face API
def get_embedding(text: str):
    """Request a sentence embedding for ``text`` from the Hugging Face API.

    Returns the decoded JSON response (for a single input string the
    feature-extraction pipeline returns a flat list of floats).
    Raises HTTPException with the upstream status and detail on API errors,
    or 500 on network failures.
    """
    try:
        # timeout prevents a hung upstream call from blocking this worker
        # forever (the original call had no timeout).
        response = requests.post(API_URL, headers=HEADERS, json={"inputs": text}, timeout=30)
    except requests.RequestException as e:
        raise HTTPException(status_code=500, detail=str(e))
    if response.status_code != 200:
        # Upstream error bodies are not always JSON (e.g. HTML gateway
        # errors) — fall back to raw text instead of crashing on decode.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        raise HTTPException(status_code=response.status_code, detail=detail)
    return response.json()
# Import-time diagnostics/data load: log how many vectors the index holds,
# then load the article metadata that FAISS indices map into.
print(f"FAISS index size: {index.ntotal}") # Total stored vectors
news_df = pd.read_csv("news_dataset.csv") # Ensure this file is in the correct directory
@app.post("/get_Emd_Corrected")
async def generate_embedding(request: EmbeddingRequest):
    """Embed the request text and return the k nearest news articles.

    Response contains the raw embedding, FAISS distances/indices, and the
    matching rows of ``news_df``. Errors are returned as JSON dicts rather
    than raised, matching the original contract.
    """
    try:
        embedding = np.array(get_embedding(request.text), dtype="float32")
        # MiniLM emits 384-dim vectors; anything else means the API
        # returned an unexpected payload (e.g. an error dict).
        if embedding.shape[0] != 384:
            return {"error": f"Expected embedding of size 384, got {embedding.shape[0]}"}
        if index is None:
            return {"error": "FAISS index not loaded"}
        k = 10  # number of nearest neighbours to retrieve
        distances, indices = index.search(embedding.reshape(1, -1), k)
        # Retrieve news articles based on indices
        results = []
        for rank, idx in enumerate(indices[0]):
            # FAISS pads results with -1 when fewer than k neighbours exist.
            # The original `idx < len(news_df)` check let -1 through, which
            # iloc silently resolves to the LAST row — require a valid index.
            if 0 <= idx < len(news_df):
                article = news_df.iloc[int(idx)].to_dict()
                article["distance"] = float(distances[0][rank])  # similarity score
                results.append(article)
        return {
            "embedding": embedding.tolist(),
            "Distances": distances.tolist(),
            "Indices": indices.tolist(),
            "results": results
        }
    except Exception as e:
        # Broad catch keeps the endpoint returning JSON on any failure.
        return {"error": str(e), "traceback": traceback.format_exc()}
import re
def clean_arabic_text(text):
    """Strip control and invisible characters that break JSON decoding.

    Removes ASCII control characters (0x00-0x1F, 0x7F), the pop-directional
    mark (U+202C) and the BOM (U+FEFF), then trims surrounding whitespace.
    """
    invisible = re.compile(r"[\x00-\x1F\x7F\u202c\ufeff]")
    return invisible.sub("", text).strip()
@app.post("/get_Emd_Data")
async def generate_embedding(request: EmbeddingRequest):
    """Clean, base64-encode and embed the request text, then return the
    k nearest news articles (without the raw embedding).

    NOTE(review): this function reuses the name `generate_embedding` from the
    /get_Emd_Corrected handler above, shadowing it in the module namespace
    (both routes still register). Also, the text is base64-encoded BEFORE
    being sent to the embedding model, so the model embeds base64 gibberish
    rather than the cleaned text — presumably a workaround for encoding
    errors; confirm this is intended.
    """
    try:
        request.text = clean_arabic_text(request.text)
        encoded_text = base64.b64encode(request.text.encode()).decode()  # Encode text in Base64
        # Get the embedding
        embedding = np.array(get_embedding(encoded_text), dtype="float32")
        if embedding.shape[0] != 384:
            return {"error": f"Expected embedding of size 384, got {embedding.shape[0]}"}
        # Check if FAISS index is loaded
        if index is None:
            return {"error": "FAISS index not loaded"}
        # Search FAISS index
        k = 10  # Number of nearest neighbors
        distances, indices = index.search(embedding.reshape(1, -1), k)
        # Retrieve news articles based on indices
        results = []
        for rank, idx in enumerate(indices[0]):
            # FAISS pads results with -1 when fewer than k neighbours exist.
            # The original `idx < len(news_df)` check let -1 through, which
            # iloc silently resolves to the LAST row — require a valid index.
            if 0 <= idx < len(news_df):
                article = news_df.iloc[int(idx)].to_dict()
                article["distance"] = float(distances[0][rank])  # similarity score
                results.append(article)
        return {"results": results}
    except Exception as e:
        # Broad catch keeps the endpoint returning JSON on any failure.
        return {"error": str(e), "traceback": traceback.format_exc()}
# FastAPI endpoint to retrieve the last stored embedding
@app.get("/last-embedding")
async def get_last_embedding():
    """Return the most recently stored embedding, or 404 if none exists.

    NOTE(review): `global_embedding` is never assigned anywhere in this
    module, so as written this endpoint always returns 404 — confirm whether
    an endpoint was meant to store into it.
    """
    if global_embedding is not None:
        return {"last_embedding": global_embedding}
    raise HTTPException(status_code=404, detail="No embedding stored yet")