| import torch | |
| import io | |
| import os | |
| from urdu_turn_detection import UrduTurnDetector | |
class EndpointHandler:
    """Request handler that runs turn detection on one audio input.

    The handler wraps a single ``UrduTurnDetector`` instance, created once in
    ``__init__`` and reused for every call.  Every response produced by
    ``__call__`` is a plain ``dict`` carrying a ``"status"`` key, so callers
    can branch on success/failure without inspecting which other keys exist.
    """

    def __init__(self, path=""):
        """Load the detector.

        Args:
            path: Directory containing the model weights and this script;
                passed unchanged to ``UrduTurnDetector.from_pretrained``.
        """
        self.detector = UrduTurnDetector.from_pretrained(path)

    @staticmethod
    def _as_json_scalar(value):
        """Return ``value.item()`` when ``value`` has a callable ``item``, else ``value``.

        NOTE(review): the detector's result types are not visible from this
        file.  If they turn out to be NumPy/torch scalars, unwrapping them
        here keeps the response JSON-serializable; plain Python values pass
        through untouched.
        """
        item = getattr(value, "item", None)
        return item() if callable(item) else value

    def __call__(self, data):
        """Run the detector on the audio referenced by ``data["inputs"]``.

        Args:
            data (:obj:`dict`):
                Request payload.  ``data["inputs"]`` is either a ``str``
                (forwarded to the detector as-is, e.g. a path or URL) or a
                bytes-like object (wrapped in ``io.BytesIO`` first).

        Return:
            A :obj:`dict` that will be serialized to JSON:

            * success: ``{"prediction", "confidence", "status": "success"}``
            * failure: ``{"error", "status": "failed"}``
        """
        inputs = data.get("inputs")
        if inputs is None:
            # Carry the same "status" marker as the exception path below so
            # every failure response has the same shape.
            return {
                "error": "Missing 'inputs' key in request data",
                "status": "failed",
            }

        try:
            # Strings are handed over untouched; anything else is assumed to
            # be raw audio bytes and exposed to the detector as a file-like
            # object.  Unsupported types raise here and are reported below.
            audio_data = inputs if isinstance(inputs, str) else io.BytesIO(inputs)

            # Use the library's abstraction
            result = self.detector.predict(audio_data)
            return {
                "prediction": self._as_json_scalar(result.label),
                "confidence": self._as_json_scalar(result.confidence),
                "status": "success",
            }
        except Exception as e:
            # Endpoint boundary: report the failure to the client in the
            # response body instead of letting the exception propagate.
            return {"error": str(e), "status": "failed"}
| # End of handler.py | |