Spaces:
Runtime error
Runtime error
import pickle
from functools import lru_cache

import numpy as np
from fastapi import HTTPException
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer

from schemas.input_schemas import CosineSimilarityResponse, EmbeddingResponse
@lru_cache(maxsize=8)
def load_model(
    model_path="models/sms_classifier_model.pkl",
    vectorizer_path="models/tfidf_vectorizer.pkl",
):
    """Load the trained SMS classifier and its vectorizer from pickle files.

    Results are memoized per ``(model_path, vectorizer_path)`` so the files
    are read once per process instead of on every request. A failed load is
    not cached (``lru_cache`` does not store exceptions) and is retried on the
    next call. Call ``load_model.cache_clear()`` after replacing the files on
    disk. Every caller receives the *same* objects, so they must be treated
    as read-only.

    Args:
        model_path: Path to the pickled classifier. Relative paths resolve
            against the process's current working directory.
        vectorizer_path: Path to the pickled vectorizer.

    Returns:
        A ``(classifier, vectorizer)`` tuple.

    Raises:
        HTTPException: status 500 if either file cannot be opened or
            unpickled.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load trusted artifacts shipped with the service; never pass a path
    # derived from request data.
    try:
        with open(model_path, "rb") as f:
            classifier = pickle.load(f)
        with open(vectorizer_path, "rb") as f:
            vectorizer = pickle.load(f)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error loading model: {str(e)}"
        ) from e
    return classifier, vectorizer
async def predict_label(message: str):
    """Classify a single message with the pickled classifier.

    NOTE: the body never awaits, so the vectorize/predict work (and the first
    disk load in ``load_model``) runs directly on the event loop.

    Args:
        message: Raw message text to classify.

    Returns:
        ``{"label": label}``, where ``label`` is the classifier's prediction
        for the message. NumPy scalars are converted to the equivalent
        built-in Python value so the dict can be JSON-encoded.

    Raises:
        HTTPException: status 500 if the model cannot be loaded (the loader's
            own error is propagated unchanged) or prediction fails.
    """
    try:
        classifier, vectorizer = load_model()
        # transform expects an iterable of documents; wrap the single message.
        message_vec = vectorizer.transform([message])
        label = classifier.predict(message_vec)[0]
        # predict returns a NumPy array, so this element is a NumPy scalar
        # (e.g. np.int64 / np.str_). FastAPI's encoder cannot serialize
        # non-builtin subclasses like np.int64; unwrap it.
        if isinstance(label, np.generic):
            label = label.item()
        return {"label": label}
    except HTTPException:
        # Already a well-formed HTTP error (e.g. from load_model): don't
        # re-wrap it into a nested "500: ..." detail string.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error predicting label: {str(e)}"
        ) from e
async def compute_cosine_similarity(text1: str, text2: str):
    """Compute the cosine similarity of two texts in the vectorizer's space.

    Only the vectorizer returned by ``load_model`` is used; the classifier is
    ignored.

    Args:
        text1: First text.
        text2: Second text.

    Returns:
        ``CosineSimilarityResponse`` whose ``cosine_similarity`` is a Python
        float. When either text maps to an all-zero vector (e.g. it is empty
        or contains no known terms), the similarity is undefined and 0.0 is
        returned instead of NaN, which is not valid JSON.

    Raises:
        HTTPException: status 500 if the model cannot be loaded (propagated
            unchanged) or vectorization fails.
    """
    try:
        _, vectorizer = load_model()
        # One transform call for both texts: row 0 is text1, row 1 is text2.
        vectors = vectorizer.transform([text1, text2]).toarray()
        vec1, vec2 = vectors[0], vectors[1]
        denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        # Guard the zero-norm case: dividing would yield NaN (plus a
        # RuntimeWarning), which cannot be serialized into the response.
        similarity = float(np.dot(vec1, vec2) / denominator) if denominator else 0.0
        return CosineSimilarityResponse(cosine_similarity=similarity)
    except HTTPException:
        # Don't re-wrap errors that are already HTTP errors (e.g. load failures).
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error computing similarity: {str(e)}"
        ) from e
async def get_embedding(message: str):
    """Return the vectorizer's representation of ``message`` as nested lists.

    Despite the name, this is the pickled vectorizer's (sparse) output,
    densified. It is not a SentenceTransformer embedding, which this function
    does not use.

    Args:
        message: Text to vectorize.

    Returns:
        ``EmbeddingResponse`` whose ``embeddings`` is a list holding a single
        row (a list of floats, presumably one per vocabulary term of the
        TF-IDF vectorizer).

    Raises:
        HTTPException: status 500 if the model cannot be loaded (propagated
            unchanged) or vectorization fails.
    """
    try:
        _, vectorizer = load_model()
        # toarray() densifies the 1 x n_features sparse row; tolist() makes it
        # JSON-serializable.
        embedding = vectorizer.transform([message]).toarray().tolist()
        return EmbeddingResponse(embeddings=embedding)
    except HTTPException:
        # Already an HTTP error (e.g. from load_model): don't nest it.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error computing embeddings: {str(e)}"
        ) from e