# self-trained / main.py
# Source: Hugging Face file view (uploader: DeepImagix; commit 35d9ab2, "Create main.py"; 3.04 kB).
import os
import joblib
import pandas as pd
from fastapi import FastAPI, Form, HTTPException, Request
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.exceptions import NotFittedError
# --- 1. Basic Setup & Configuration ---
# ASGI application object; the route decorators further down register on it.
app = FastAPI()
# Every user's model and training data live beneath this directory. The path is
# relative, so it resolves against the process working directory.
# NOTE(review): the original author states that Hugging Face provides persistent
# storage at the root of the project — confirm for the actual deployment.
USER_MODELS_DIR = "user_models_data"
os.makedirs(USER_MODELS_DIR, exist_ok=True)
# --- Label Mapping ---
# Intent names listed in class-id order: a name's position in the tuple is the
# integer the classifier is trained on.
label_mapping = {
    intent: class_id
    for class_id, intent in enumerate((
        'positive_sentiment',
        'negative_sentiment',
        'greeting',
        'farewell',
        'thanks',
        'searching_inquiry',
    ))
}
# Inverse lookup (class id -> intent name) used to turn predictions back into labels.
reverse_label_mapping = dict(zip(label_mapping.values(), label_mapping.keys()))
# --- 2. Helper Functions ---
def get_user_paths(user_id: str) -> dict:
    """Return the storage paths for one user's model and training data.

    The user's directory is created under ``USER_MODELS_DIR`` if it does not
    exist yet.

    Args:
        user_id: Client-supplied identifier. It is used as a single directory
            name, so it must not be empty, must not be ``.`` or ``..``, and
            must not contain path separators or NUL characters.

    Returns:
        A dict with ``"model_path"`` (the joblib model file) and
        ``"data_path"`` (the training-data CSV) inside the user's directory.

    Raises:
        HTTPException: 400 if ``user_id`` could resolve to a location other
            than a direct child of ``USER_MODELS_DIR``.
    """
    # user_id arrives straight from form data; without this check a value such
    # as "../x" or "/etc" would create directories and read/write files outside
    # the per-user storage area.
    has_forbidden_char = any(ch in user_id for ch in ("/", "\\", "\x00"))
    if not user_id or user_id in (".", "..") or has_forbidden_char:
        raise HTTPException(status_code=400, detail="Invalid user_id.")

    user_dir = os.path.join(USER_MODELS_DIR, user_id)
    # Defence in depth: whatever the platform's path rules are (e.g. a
    # drive-qualified name on Windows), the result must be a direct child of
    # the storage root.
    storage_root = os.path.abspath(USER_MODELS_DIR)
    if os.path.dirname(os.path.abspath(user_dir)) != storage_root:
        raise HTTPException(status_code=400, detail="Invalid user_id.")

    os.makedirs(user_dir, exist_ok=True)
    return {
        "model_path": os.path.join(user_dir, "model.joblib"),
        "data_path": os.path.join(user_dir, "training_data.csv"),
    }
# --- 3. API Endpoints ---
@app.get("/")
def read_root():
    """Answer GET / with a fixed welcome payload."""
    welcome = {"message": "Welcome! Your AI is running on Hugging Face Spaces."}
    return welcome
@app.post("/predict/")
async def predict(user_id: str = Form(...), text: str = Form(...)):
    """Classify ``text`` with the model previously trained for ``user_id``.

    Args:
        user_id: Form field selecting whose stored model to use.
        text: Form field holding the text to classify.

    Returns:
        ``{"intent": <label name>, "confidence": <float>}`` where confidence
        is the probability the model assigns to the predicted class.

    Raises:
        HTTPException: 404 if no model file exists for the user; 500 (with the
            underlying error text as detail) if prediction fails.
    """
    paths = get_user_paths(user_id)
    if not os.path.exists(paths["model_path"]):
        raise HTTPException(status_code=404, detail="Model not found. Please train it first.")
    # NOTE(security): joblib.load unpickles the file and can execute arbitrary
    # code. It is only acceptable while model files are written exclusively by
    # this service's own /train/ endpoint.
    model_pipeline = joblib.load(paths["model_path"])
    try:
        predicted_class = model_pipeline.predict([text])[0]
        probabilities = model_pipeline.predict_proba([text])[0]
        # predict_proba has one column per entry of classes_, i.e. only the
        # labels this user's model was trained on, in classes_ order. The
        # numeric label is therefore not a valid column index (a model trained
        # on labels {2, 4} has just two columns), so look its position up.
        class_position = list(model_pipeline.classes_).index(predicted_class)
        predicted_label = reverse_label_mapping[int(predicted_class)]
        confidence = float(probabilities[class_position])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"intent": predicted_label, "confidence": confidence}
@app.post("/train/")
async def train(user_id: str = Form(...), text: str = Form(...), label: str = Form(...)):
    """Store one labelled example for ``user_id`` and retrain that user's model.

    The example is appended to the user's CSV file, then a fresh TF-IDF +
    SGD pipeline is fitted on every stored example and saved with joblib.

    Args:
        user_id: Form field selecting whose data and model to update.
        text: Form field holding the example text; must not be blank.
        label: Form field holding the intent name; must be a key of
            ``label_mapping``.

    Returns:
        A ``{"message": ...}`` dict: either a success message, or a notice
        that the model was not trained because fewer than two distinct labels
        are stored so far.

    Raises:
        HTTPException: 400 for an unknown label or blank text; 422 if the
            stored examples cannot be fitted (the example is still saved).
    """
    if label not in label_mapping:
        raise HTTPException(status_code=400, detail="Invalid label.")
    # Blank text adds nothing to the model, so refuse it before it is persisted.
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty.")

    data_path = get_user_paths(user_id)["data_path"]
    new_data = pd.DataFrame([{"text": text, "label": label}])
    # Write the header only when creating the file; afterwards append rows.
    is_new_file = not os.path.exists(data_path)
    new_data.to_csv(
        data_path,
        mode='w' if is_new_file else 'a',
        header=is_new_file,
        index=False,
    )

    # Read everything back as plain strings. With pandas' defaults, texts such
    # as "NA", "null" or "" become NaN and all-numeric texts become ints, both
    # of which make TfidfVectorizer fail on every later call for this user.
    df = pd.read_csv(data_path, dtype=str, keep_default_na=False)
    # Ignore rows whose label is not in the current mapping; they would map to
    # NaN targets and break fitting.
    df = df[df['label'].isin(list(label_mapping))]
    if df['label'].nunique() < 2:
        return {"message": "Model not trained. Please provide at least two different categories of examples."}

    X = df['text']
    y = df['label'].map(label_mapping)
    model_pipeline = Pipeline([
        ('tfidf', TfidfVectorizer()),
        ('clf', SGDClassifier(loss='modified_huber', random_state=42)),
    ])
    try:
        model_pipeline.fit(X, y)
    except ValueError as e:
        # e.g. an empty vocabulary when no stored text contains a usable token.
        # Any previously saved model is left untouched.
        raise HTTPException(
            status_code=422,
            detail=f"Example saved, but training failed: {e}",
        ) from e
    joblib.dump(model_pipeline, paths_model := get_user_paths(user_id)["model_path"]) if False else joblib.dump(model_pipeline, get_user_paths(user_id)["model_path"])
    return {"message": f"Training successful for user '{user_id}'."}