# Spaces:
# Sleeping
# Sleeping
import os

import gradio as gr
import joblib
# Human-readable label names. The position of each name matches the column
# index of the corresponding 0/1 flag in a model prediction row.
class_names = [
    "Family Issues",
    "Relationship Conflicts",
    "Work Dynamics",
    "Financial and Legal Disagreements",
    "Personal Boundaries",
    "Cultural and Identity-Based Issues",
    "Other",
]
# Define the custom pipeline
class CustomSVMTextClassificationPipeline:
    """Text classifier built from a saved vectorizer and a multi-output model.

    Calling an instance vectorizes the input text(s), runs ``model.predict``
    and maps every prediction row to the names of the labels flagged ``1``.
    """

    def __init__(self, model_path, vectorizer_path, labels=None):
        """Load the model and the vectorizer with joblib.

        Args:
            model_path: Path of the serialized classifier.
            vectorizer_path: Path of the serialized text vectorizer.
            labels: Optional label names, index-aligned with the model's
                output columns. Defaults to the module-level ``class_names``.
        """
        # NOTE(review): joblib.load unpickles arbitrary objects, so these
        # paths must only ever point at trusted files.
        self.model = joblib.load(model_path)
        self.vectorizer = joblib.load(vectorizer_path)
        self.class_names = class_names if labels is None else list(labels)

    def __call__(self, texts):
        """Classify one text or a sequence of texts.

        Args:
            texts: A single string, or an iterable of strings.

        Returns:
            A list of predicted label names when exactly one text is given
            (as a string or as a one-element sequence). Otherwise, one such
            list per input text, in input order. An empty sequence yields
            ``[]``.
        """
        # The vectorizer expects a collection of documents. A list is also
        # needed so that generators can be length-checked below.
        texts = [texts] if isinstance(texts, str) else list(texts)
        if not texts:
            # Nothing to classify; avoid calling the model on zero samples.
            return []

        features = self.vectorizer.transform(texts)
        predictions = self.model.predict(features)

        # NOTE(review): assumes predict() returns one row per text with one
        # 0/1 column per label (multi-output model) — confirm for the saved
        # artifact.
        results = [
            [self.class_names[i] for i, flag in enumerate(row) if flag == 1]
            for row in predictions
        ]
        # A lone result is returned unwrapped (existing contract).
        return results[0] if len(results) == 1 else results
# Load the model and vectorizer.
# Relative artifact paths are resolved against this script's directory rather
# than the process's current working directory, so the app also starts when
# launched from elsewhere. os.path.join keeps absolute paths unchanged.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# NOTE(review): these files are unpickled by joblib; only use trusted artifacts.
model_path = os.path.join(_BASE_DIR, "svm_multi_output_model.pkl")  # Replace with your model file path
vectorizer_path = os.path.join(_BASE_DIR, "tfidf_vectorizer.pkl")  # Replace with your vectorizer file path
classifier = CustomSVMTextClassificationPipeline(model_path, vectorizer_path)
def classify_text(input_text):
    """Classify the text entered in the UI with the module-level classifier.

    Args:
        input_text: Raw textbox contents; may be empty or ``None``.

    Returns:
        The list of predicted class names for the text, or ``[]`` when the
        input is missing or contains only whitespace.
    """
    # Blank input carries no signal. Without this guard the model would
    # return arbitrary labels for it, and None would crash the vectorizer.
    if input_text is None or not input_text.strip():
        return []
    return classifier(input_text)
# Assemble the UI. The page has a title, a prompt, a textbox for the text,
# a JSON panel for the predicted labels, and a button that runs classify_text
# on the textbox contents.
with gr.Blocks() as app:
    gr.Markdown("# Text Classification App")
    gr.Markdown("Enter text to classify:")
    input_text = gr.Textbox(label="Input Text")
    output = gr.JSON(label="Classification Results")
    submit_button = gr.Button("Classify")
    submit_button.click(
        fn=classify_text,
        inputs=[input_text],
        outputs=[output],
    )

# Serve the app only when executed as a script, not when imported.
if __name__ == "__main__":
    app.launch()