import logging
import gradio as gr
from queue import Queue
import time
from prometheus_client import start_http_server, Counter, Histogram, Gauge
import threading
import psutil
import random
from transformers import pipeline
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import requests
from datasets import load_dataset
import os
from logging import FileHandler
# Ensure the log files exist
log_file_path = 'chat_log.log'
debug_log_file_path = 'debug.log'
if not os.path.exists(log_file_path):
    with open(log_file_path, 'w') as f:
        f.write(" ")
if not os.path.exists(debug_log_file_path):
    with open(debug_log_file_path, 'w') as f:
        f.write(" ")
# Create logger instance
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)  # Set logger level to the lowest level needed

# Create formatter
formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')

# Create handlers
info_handler = FileHandler(filename=log_file_path, mode='w+')
info_handler.setLevel(logging.INFO)
info_handler.setFormatter(formatter)

debug_handler = FileHandler(filename=debug_log_file_path, mode='w+')
debug_handler.setLevel(logging.DEBUG)
debug_handler.setFormatter(formatter)
# Custom logging handler that captures log records for display in the Gradio UI
class GradioHandler(logging.Handler):
    def __init__(self, logs_queue):
        super().__init__()
        self.logs_queue = logs_queue

    def emit(self, record):
        log_entry = self.format(record)
        self.logs_queue.put(log_entry)
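# Each formatted record is pushed onto logs_queue; the update_logs thread
# defined further down drains that queue to populate the Logs tab.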
# Create a logs queue
logs_queue = Queue()

# Create and configure Gradio handler
gradio_handler = GradioHandler(logs_queue)
gradio_handler.setLevel(logging.INFO)
gradio_handler.setFormatter(formatter)

# Add handlers to the logger
logger.addHandler(info_handler)
logger.addHandler(debug_handler)
logger.addHandler(gradio_handler)
# Load the model
try:
    ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner")
    logger.debug("NER pipeline loaded.")
except Exception as e:
    logger.error(f"Error loading NER pipeline: {e}")

# Load the dataset
try:
    dataset = load_dataset("surrey-nlp/PLOD-filtered")
    logger.debug("Dataset loaded.")
except Exception as e:
    logger.error(f"Error loading dataset: {e}")
# --- Prometheus Metrics Setup ---
try:
    REQUEST_COUNT = Counter('gradio_request_count', 'Total number of requests')
    REQUEST_LATENCY = Histogram('gradio_request_latency_seconds', 'Request latency in seconds')
    ERROR_COUNT = Counter('gradio_error_count', 'Total number of errors')
    RESPONSE_SIZE = Histogram('gradio_response_size_bytes', 'Size of responses in bytes')
    CPU_USAGE = Gauge('system_cpu_usage_percent', 'System CPU usage in percent')
    MEM_USAGE = Gauge('system_memory_usage_percent', 'System memory usage in percent')
    QUEUE_LENGTH = Gauge('chat_queue_length', 'Length of the chat queue')
    logger.debug("Prometheus metrics setup complete.")
except Exception as e:
    logger.error(f"Error setting up Prometheus metrics: {e}")
# --- Queue and Metrics ---
chat_queue = Queue()  # Define chat_queue globally

label_mapping = {
    0: 'B-O',
    1: 'B-AC',
    3: 'B-LF',
    4: 'I-LF'
}
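# Note: these keys must match the numeric ids in the model's id2label config
# (the pipeline emits entities such as "LABEL_3", parsed below). Id 2 has no
# entry in this mapping; if the model ever predicts it, the lookup in
# classification() raises a KeyError, so label_mapping.get(label_id, 'B-O')
# would be a safer alternative.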
def classification(message):
    # Predict using the model
    ner_results = ner_pipeline(" ".join(message))

    detailed_response = []
    model_predicted_labels = []
    for result in ner_results:
        token = result['word']
        score = result['score']
        entity = result['entity']
        label_id = int(entity.split('_')[-1])  # Extract numeric label from entity
        model_predicted_labels.append(label_mapping[label_id])
        detailed_response.append(f"Token: {token}, Entity: {label_mapping[label_id]}, Score: {score:.4f}")

    response = "\n".join(detailed_response)
    response_size = len(response.encode('utf-8'))
    RESPONSE_SIZE.observe(response_size)
    time.sleep(random.uniform(0.5, 2.5))  # Simulate processing time
    return response, model_predicted_labels
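# A sketch of the expected shape of a call (tokens and labels here are
# hypothetical): classification(["Polymerase", "chain", "reaction"]) returns
# a per-token report string plus a label list such as ['B-LF', 'I-LF', 'I-LF'].
# Because the pipeline re-tokenizes into subwords, the predicted label list
# can be longer than the input token list.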
# --- Chat Function with Monitoring ---
def chat_function(user_input, datasets=None):
    # datasets defaults to None because the "Sentence input" tab calls this
    # function with a single argument
    logger.debug("Starting chat_function")
    with REQUEST_LATENCY.time():
        REQUEST_COUNT.inc()
        try:
            if user_input.isnumeric():
                chat_queue.put(user_input)
                # Get the example from the uploaded dataset if provided,
                # otherwise from the train split of PLOD-filtered
                if datasets:
                    example = datasets[int(user_input)]
                else:
                    example = dataset['train'][int(user_input)]
                tokens = example['tokens']
                ground_truth_labels = [label_mapping[label] for label in example['ner_tags']]

                # Call the classification function
                response, model_predicted_labels = classification(tokens)

                # Truncate predictions so both label lists have the same length;
                # this is a rough alignment, since subword tokenization can
                # yield more predictions than input tokens
                model_predicted_labels = model_predicted_labels[:len(ground_truth_labels)]

                precision = precision_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
                recall = recall_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
                f1 = f1_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
                accuracy = accuracy_score(ground_truth_labels, model_predicted_labels)

                metrics_response = (f"Precision: {precision:.4f}\n"
                                    f"Recall: {recall:.4f}\n"
                                    f"F1 Score: {f1:.4f}\n"
                                    f"Accuracy: {accuracy:.4f}")

                full_response = (f"**Record**:\nTokens: {tokens}\nGround Truth Labels: {ground_truth_labels}\n\n"
                                 f"**Predictions**:\n{response}\n\n**Metrics**:\n{metrics_response}")
                logger.info(f"Input details:\nReceived index from user: {user_input}\nSending response to user: {full_response}")
            else:
                chat_queue.put(user_input)
                response, _ = classification([user_input])
                full_response = f"Input details:\n**Input Sentence:** {user_input}\n\n**Predictions**:\n{response}\n\n"
                logger.info(full_response)
            chat_queue.get()
            return full_response
        except Exception as e:
            ERROR_COUNT.inc()
            logger.error(f"Error in chat processing: {e}", exc_info=True)
            return f"An error occurred. Please try again. Error: {e}"
# Function to simulate a stress test against the running app
def stress_test(num_requests, message, delay):
    def send_chat_message():
        try:
            response = requests.post("http://127.0.0.1:7860/api/predict/", json={
                "data": [message],
                "fn_index": 0  # This might need to be updated based on your Gradio app's function index
            })
            logger.debug(f"Request payload: {message}")
            logger.debug(f"Response: {response.json()}")
        except Exception as e:
            logger.debug(f"Error during stress test request: {e}", exc_info=True)

    threads = []
    for _ in range(num_requests):
        t = threading.Thread(target=send_chat_message)
        t.start()
        threads.append(t)
        time.sleep(delay)  # Delay between requests

    for t in threads:
        t.join()
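# The /api/predict/ route and fn_index convention above are version-dependent.
# On newer Gradio releases, a minimal sketch using the official client
# (assumes `pip install gradio_client` and that the endpoint is named
# "/predict") would be:
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   client.predict(message, api_name="/predict")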
# --- Gradio Interface with Monitoring Tabs ---
with gr.Blocks(title="PLOD Filtered with Monitoring") as demo:
    with gr.Tab("Sentence input"):
        gr.Markdown("## Chat with the Bot")
        index_input = gr.Textbox(label="Enter a sentence:", lines=1)
        output = gr.Markdown(label="Response")
        chat_interface = gr.Interface(fn=chat_function, inputs=[index_input], outputs=output)

    with gr.Tab("Dataset and Index Input"):
        gr.Markdown("## Chat with the Bot")
        interface = gr.Interface(fn=chat_function,
                                 inputs=[gr.Textbox(label="Enter dataset index:", lines=1),
                                         gr.UploadButton(label="Upload Dataset", file_types=[".csv", ".tsv"])],
                                 outputs=gr.Markdown(label="Response"))

    with gr.Tab("Model Parameters"):
        model_params_display = gr.Textbox(label="Model Parameters", lines=20, interactive=False)  # Display model parameters

    with gr.Tab("Performance Metrics"):
        request_count_display = gr.Number(label="Request Count", value=0)
        avg_latency_display = gr.Number(label="Avg. Response Time (s)", value=0)

    with gr.Tab("Infrastructure"):
        cpu_usage_display = gr.Number(label="CPU Usage (%)", value=0)
        mem_usage_display = gr.Number(label="Memory Usage (%)", value=0)

    with gr.Tab("Logs"):
        logs_display = gr.Textbox(label="Logs", lines=10)  # Increased lines for better visibility

    with gr.Tab("Stress Testing"):
        num_requests_input = gr.Number(label="Number of Requests", value=10)
        index_input_stress = gr.Textbox(label="Dataset Index", value="2")
        delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1)
        stress_test_button = gr.Button("Start Stress Test")
        stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False)

        def run_stress_test(num_requests, index, delay):
            # Return the status string so Gradio updates the textbox;
            # assigning to stress_test_status.value inside a callback does
            # not refresh the UI
            try:
                stress_test(int(num_requests), index, delay)
                return "Stress test completed."
            except Exception as e:
                return f"Stress test failed: {e}"

        stress_test_button.click(run_stress_test, [num_requests_input, index_input_stress, delay_input], stress_test_status)

    img = gr.Image("stag.jpeg", label="Image")
# --- Update Functions ---
def update_metrics(request_count_display, avg_latency_display):
    while True:
        request_count = REQUEST_COUNT._value.get()  # Read the client library's internal counter value
        # Compute average latency from the histogram's _sum and _count samples
        samples = {s.name: s.value for s in REQUEST_LATENCY.collect()[0].samples}
        count = samples.get('gradio_request_latency_seconds_count', 0)
        total = samples.get('gradio_request_latency_seconds_sum', 0)
        avg_latency = total / count if count else 0  # Avoid division by zero
        request_count_display.value = request_count
        avg_latency_display.value = round(avg_latency, 2)
        time.sleep(5)  # Update every 5 seconds

def update_usage(cpu_usage_display, mem_usage_display):
    while True:
        cpu_usage_display.value = psutil.cpu_percent()
        mem_usage_display.value = psutil.virtual_memory().percent
        CPU_USAGE.set(psutil.cpu_percent())
        MEM_USAGE.set(psutil.virtual_memory().percent)
        time.sleep(5)

def update_logs(logs_display):
    while True:
        logs = []
        # Drain the queue and keep the ten most recent entries
        while not logs_queue.empty():
            logs.append(logs_queue.get())
        logs_display.value = "\n".join(logs[-10:])
        time.sleep(1)  # Update every 1 second

def display_model_params(model_params_display):
    while True:
        model_params = ner_pipeline.model.config.to_dict()
        model_params_str = "\n".join(f"{key}: {value}" for key, value in model_params.items())
        model_params_display.value = model_params_str
        time.sleep(10)  # Update every 10 seconds

def update_queue_length():
    while True:
        QUEUE_LENGTH.set(chat_queue.qsize())
        time.sleep(1)  # Update every second
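# Caveat: the threads below mutate component .value attributes directly, which
# current Gradio versions do not push to connected browsers; live refresh
# would need a polling mechanism such as the components' `every=` parameter.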
# --- Start Threads ---
threading.Thread(target=start_http_server, args=(8000,), daemon=True).start()
threading.Thread(target=update_metrics, args=(request_count_display, avg_latency_display), daemon=True).start()
threading.Thread(target=update_usage, args=(cpu_usage_display, mem_usage_display), daemon=True).start()
threading.Thread(target=update_logs, args=(logs_display,), daemon=True).start()  # Note the trailing comma: args must be a tuple
threading.Thread(target=display_model_params, args=(model_params_display,), daemon=True).start()
threading.Thread(target=update_queue_length, daemon=True).start()
# Launch the app
demo.launch(share=True)