Spaces:
Runtime error
Runtime error
| import os | |
| import joblib | |
| import pandas as pd | |
| from fastapi import FastAPI, Form, HTTPException, Request | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.linear_model import SGDClassifier | |
| from sklearn.pipeline import Pipeline | |
| from sklearn.exceptions import NotFittedError | |
# --- 1. Basic Setup & Configuration ---
# Base directory holding one sub-directory per user (model + training CSV).
# A relative path resolves against the process working directory, which on
# Hugging Face Spaces is ephemeral; only the optional persistent-storage
# volume (mounted at /data) survives restarts. Set the USER_MODELS_DIR
# environment variable (e.g. to /data/user_models_data) to keep trained
# models, or to any writable location if the working directory is read-only.
USER_MODELS_DIR = os.environ.get("USER_MODELS_DIR", "user_models_data")
os.makedirs(USER_MODELS_DIR, exist_ok=True)
# The single FastAPI application object for this module.
app: FastAPI = FastAPI()
# --- Label Mapping ---
# Intent names listed in class-index order: the position of each name is the
# integer target used when training the classifier.
_INTENT_NAMES = (
    "positive_sentiment",
    "negative_sentiment",
    "greeting",
    "farewell",
    "thanks",
    "searching_inquiry",
)
# Intent name -> class index.
label_mapping = {name: index for index, name in enumerate(_INTENT_NAMES)}
# Class index -> intent name (inverse of label_mapping).
reverse_label_mapping = dict(zip(label_mapping.values(), label_mapping.keys()))
# --- 2. Helper Functions ---
def get_user_paths(user_id: str) -> dict:
    """Return the model and training-data file paths for one user.

    The user's directory ``USER_MODELS_DIR/<user_id>`` is created if it does
    not exist yet.

    Args:
        user_id: Client-supplied identifier. It must be a single, ordinary
            path component (no separators, not "." or "..", no NUL byte).

    Returns:
        Dict with ``"model_path"`` (joblib model file) and ``"data_path"``
        (training CSV), both inside the user's directory.

    Raises:
        HTTPException: 400 if ``user_id`` is not a valid single path
            component.
    """
    # user_id arrives straight from form input. Without this check a value
    # such as "../x" or "/etc" would make the service read, overwrite or
    # unpickle files outside USER_MODELS_DIR.
    separators = {os.sep, os.altsep} - {None}
    if (
        user_id in ("", ".", "..")
        or "\x00" in user_id
        or any(sep in user_id for sep in separators)
    ):
        raise HTTPException(status_code=400, detail="Invalid user_id.")

    user_dir = os.path.join(USER_MODELS_DIR, user_id)
    os.makedirs(user_dir, exist_ok=True)
    return {
        "model_path": os.path.join(user_dir, "model.joblib"),
        "data_path": os.path.join(user_dir, "training_data.csv"),
    }
# --- 3. API Endpoints ---
def read_root():
    """Return a fixed welcome message as a JSON-serialisable dict."""
    # NOTE(review): no route decorator (e.g. @app.get("/")) is visible on this
    # handler in this copy of the file; confirm it is registered on `app`.
    welcome_text = "Welcome! Your AI is running on Hugging Face Spaces."
    return dict(message=welcome_text)
async def predict(user_id: str = Form(...), text: str = Form(...)):
    """Classify ``text`` with the model previously trained for ``user_id``.

    Args:
        user_id: Identifier whose saved model is used (form field).
        text: Text to classify (form field).

    Returns:
        ``{"intent": <label name>, "confidence": <float>}``, where confidence
        is the model's probability estimate for the predicted label.

    Raises:
        HTTPException: 404 if no model file exists for ``user_id``; 500 if
            prediction fails.
    """
    # NOTE(review): no route decorator is visible on this handler in this copy
    # of the file; confirm it is registered on `app`.
    paths = get_user_paths(user_id)
    if not os.path.exists(paths["model_path"]):
        raise HTTPException(status_code=404, detail="Model not found. Please train it first.")
    # joblib.load unpickles arbitrary objects: the file must only ever be the
    # one written by train(), never client-supplied.
    model_pipeline = joblib.load(paths["model_path"])
    try:
        predicted_class = model_pipeline.predict([text])[0]
        probabilities = model_pipeline.predict_proba([text])[0]
        # predict_proba columns follow model_pipeline.classes_, which holds
        # only the labels seen during training. The class value itself is
        # therefore not a valid column index (e.g. classes_ == [2, 5]), so
        # look up its position instead.
        class_position = list(model_pipeline.classes_).index(predicted_class)
        predicted_label = reverse_label_mapping[int(predicted_class)]
        confidence = float(probabilities[class_position])
        return {"intent": predicted_label, "confidence": confidence}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def train(user_id: str = Form(...), text: str = Form(...), label: str = Form(...)):
    """Store one labelled example for ``user_id`` and retrain their model.

    The example is appended to the user's training CSV. When the CSV holds
    examples of at least two distinct labels, a TF-IDF + SGDClassifier
    pipeline is fitted on all stored examples and written with joblib to the
    user's model path, replacing any previous model file.

    Args:
        user_id: Identifier whose data and model are updated (form field).
        text: Example text (form field).
        label: Intent name; must be a key of ``label_mapping`` (form field).

    Returns:
        ``{"message": ...}`` stating whether a model was trained.

    Raises:
        HTTPException: 400 if ``label`` is unknown, or if fitting fails with
            a ValueError (e.g. the stored texts yield an empty vocabulary).
    """
    # NOTE(review): no route decorator is visible on this handler in this copy
    # of the file; confirm it is registered on `app`.
    if label not in label_mapping:
        raise HTTPException(status_code=400, detail="Invalid label.")
    paths = get_user_paths(user_id)
    data_path = paths["data_path"]

    # Append the new example; write the header only when creating the file.
    new_data = pd.DataFrame([{"text": text, "label": label}])
    file_exists = os.path.exists(data_path)
    new_data.to_csv(data_path, mode="a", header=not file_exists, index=False)

    # Read every field back as a literal string. With pandas defaults, texts
    # such as "NA", "N/A" or "null" become NaN, and a column of purely numeric
    # texts is parsed as numbers. TfidfVectorizer rejects both, which would
    # break every later training call for this user.
    df = pd.read_csv(data_path, dtype=str, keep_default_na=False)
    if df["label"].nunique() < 2:
        return {"message": "Model not trained. Please provide at least two different categories of examples."}

    X = df["text"]
    y = df["label"].map(label_mapping)
    model_pipeline = Pipeline([
        ('tfidf', TfidfVectorizer()),
        ('clf', SGDClassifier(loss='modified_huber', random_state=42)),
    ])
    try:
        model_pipeline.fit(X, y)
    except ValueError as e:
        # e.g. TfidfVectorizer's "empty vocabulary" when no stored text
        # contains a usable token; report it to the client instead of a 500.
        raise HTTPException(status_code=400, detail=f"Model not trained: {e}") from e
    joblib.dump(model_pipeline, paths["model_path"])
    return {"message": f"Training successful for user '{user_id}'."}