Spaces:
Sleeping
Sleeping
| import torch | |
| import os | |
| import json | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
# Hugging Face Hub repository holding the model weights, vocabulary,
# label encoder and config consumed by this app.
repo_id = "iimran/AnalyserV2"


def download_model_files(repo_id):
    """Fetch the model artifacts for ``repo_id`` from the Hugging Face Hub.

    Args:
        repo_id: Hub repository identifier, e.g. ``"iimran/AnalyserV2"``.

    Returns:
        A 4-tuple of local file paths, in this order:
        ``(model_path, vocab_path, label_encoder_path, config_path)``.
    """
    artifact_names = (
        "model_weights.pth",
        "vocab.json",
        "label_encoder.json",
        "config.json",
    )
    # Downloads happen one after another in the order callers unpack them.
    return tuple(
        hf_hub_download(repo_id=repo_id, filename=name) for name in artifact_names
    )
def get_transformer_model_class():
    """Materialise the ``TransformerModel`` class from the ``MODEL_VAR`` env var.

    ``MODEL_VAR`` must hold Python source that defines a class named
    ``TransformerModel``. The source is executed in this module's global
    namespace, so the class (and anything else the source defines) becomes a
    module-level name.

    Returns:
        The ``TransformerModel`` object found in the module globals.

    Raises:
        ValueError: If ``MODEL_VAR`` is not set.
        NameError: If no ``TransformerModel`` name exists in the module
            globals after the source has been executed.
    """
    source = os.getenv("MODEL_VAR")
    if source is None:
        raise ValueError("Environment variable 'MODEL_VAR' is not set.")

    namespace = globals()
    # NOTE(review): exec() runs arbitrary code taken from the environment.
    # Presumably acceptable only because the deployment owner sets this
    # variable; it must never carry user-controlled input.
    exec(source, namespace)

    if "TransformerModel" in namespace:
        return namespace["TransformerModel"]
    raise NameError(
        "The TransformerModel class was not defined after executing MODEL_VAR."
    )
def get_preprocess_function():
    """Materialise the ``preprocess_text`` function from ``MODEL_PROCESS``.

    ``MODEL_PROCESS`` must hold Python source that defines a function named
    ``preprocess_text``. The source is executed in this module's global
    namespace, so the function becomes a module-level name.

    Returns:
        The ``preprocess_text`` object found in the module globals.

    Raises:
        ValueError: If ``MODEL_PROCESS`` is not set.
        NameError: If no ``preprocess_text`` name exists in the module
            globals after the source has been executed.
    """
    source = os.getenv("MODEL_PROCESS")
    if source is None:
        raise ValueError("Environment variable 'MODEL_PROCESS' is not set.")

    namespace = globals()
    # NOTE(review): exec() runs arbitrary code taken from the environment.
    # Presumably acceptable only because the deployment owner sets this
    # variable; it must never carry user-controlled input.
    exec(source, namespace)

    if "preprocess_text" in namespace:
        return namespace["preprocess_text"]
    raise NameError(
        "The preprocess_text function was not defined after executing MODEL_PROCESS."
    )
def _load_json_resource(path, description):
    """Read and parse a JSON file, translating failures into descriptive errors.

    Args:
        path: Local filesystem path of the JSON file.
        description: Lower-case human-readable resource name (for example
            ``"vocabulary"``), used in the error messages.

    Returns:
        The parsed JSON value.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file is not valid UTF-8 encoded JSON.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding, which can differ and break non-ASCII vocab entries.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"{description.capitalize()} file not found at {path}. Please check the repository."
        ) from err
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        raise ValueError(
            f"Invalid JSON format in {description} file at {path}."
        ) from err


def load_model_and_resources(repo_id):
    """Download artifacts for ``repo_id`` and build a ready-to-use model.

    Args:
        repo_id: Hugging Face Hub repository identifier.

    Returns:
        ``(model, vocab, label_encoder_classes)``, where ``model`` has its
        weights loaded on CPU and has been put into evaluation mode, and
        ``vocab`` / ``label_encoder_classes`` are the parsed JSON contents
        of ``vocab.json`` / ``label_encoder.json``.

    Raises:
        FileNotFoundError: If a JSON resource is missing on disk.
        ValueError: If a JSON resource cannot be parsed, or if ``MODEL_VAR``
            is not set (raised by ``get_transformer_model_class``).
        NameError: If ``MODEL_VAR`` does not define ``TransformerModel``.
    """
    # config.json is downloaded but not read here.
    model_path, vocab_path, label_encoder_path, _config_path = download_model_files(repo_id)

    vocab = _load_json_resource(vocab_path, "vocabulary")
    label_encoder_classes = _load_json_resource(label_encoder_path, "label encoder")

    TransformerModel = get_transformer_model_class()
    model = TransformerModel(vocab_size=len(vocab), num_classes=len(label_encoder_classes))

    # NOTE(review): torch.load unpickles the downloaded file. Consider passing
    # weights_only=True once the minimum supported torch version allows it.
    # Weights are mapped to CPU; use "cuda" here if a GPU is available.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # switch to evaluation mode for inference
    return model, vocab, label_encoder_classes
# Resolve the text preprocessor once at import time. This executes the
# MODEL_PROCESS environment variable and raises if it is missing or does not
# define `preprocess_text`.
preprocess_text = get_preprocess_function()


def predict(text, model, vocab, label_encoder_classes):
    """Classify ``text`` and return its predicted label.

    Args:
        text: Raw input string from the UI.
        model: Callable model taking ``(input_ids, attention_mask)``, as built
            by ``load_model_and_resources``.
        vocab: Vocabulary loaded from ``vocab.json``, forwarded to
            ``preprocess_text``.
        label_encoder_classes: Labels indexable by integer class id.

    Returns:
        The label at the arg-max class index of the model output.

    Raises:
        ValueError: If the model's forward pass returns ``None``.
    """
    import logging

    # Debug details go through logging instead of print so that tokenized
    # user input is not written to stdout on every request.
    logger = logging.getLogger(__name__)

    input_ids, attention_mask = preprocess_text(text, vocab)
    logger.debug("Input IDs: %s", input_ids)
    logger.debug("Attention Mask: %s", attention_mask)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
    logger.debug("Model Outputs: %s", outputs)

    if outputs is None:
        raise ValueError("Model returned None. Check the forward method and input data.")

    # Presumably outputs are (batch, num_classes) scores for a single example;
    # .item() requires exactly one arg-max value, i.e. a batch of one.
    predicted_class_idx = outputs.argmax(1).item()
    return label_encoder_classes[predicted_class_idx]
def create_gradio_interface():
    """Build the Gradio UI around the text classifier.

    Loads the model, vocabulary and label encoder for the module-level
    ``repo_id`` once, then wires one input text box to one output text box
    through ``predict``.

    Returns:
        A configured ``gr.Interface`` that has not been launched yet.
    """
    model, vocab, label_encoder_classes = load_model_and_resources(repo_id)

    def predict_wrapper(text):
        # Gradio passes only the textbox value; the loaded resources are
        # captured from the enclosing scope.
        return predict(text, model, vocab, label_encoder_classes)

    text_input = gr.Textbox(lines=2, placeholder="Enter text here...")
    label_output = gr.Textbox(label="Predicted Label")

    # One inner list per example, with one entry per input component.
    sample_inputs = [
        ["I would like to bring to your attention a pothole on Main Street that has become a safety hazard. The pothole is quite deep and poses a risk to both drivers and pedestrians. I kindly request the council to inspect and repair it at the earliest to prevent any potential accidents or vehicle damage. Please let me know if any further information is required."],
        ["I am writing to report a clogged drainage system in 1 tonsley. The blockage is causing water to accumulate, leading to potential flooding and sanitation issues. This situation poses a risk to public health and safety, especially during rainfall. I kindly request the council to inspect and resolve this issue at the earliest convenience."],
        ["I am writing to report a persistent issue of loud noise coming from my neighbors at 1 tonsley. The noise, which occurs through out the day, has been causing significant disturbance to me and other residents in the area."],
    ]

    return gr.Interface(
        fn=predict_wrapper,
        inputs=text_input,
        outputs=label_output,
        title="Text Classification Model",
        description="Enter text to classify it using the model.",
        examples=sample_inputs,
        cache_examples=False,  # run examples live rather than precomputing them
    )
if __name__ == "__main__":
    # Building the UI downloads and loads the model; launch() then serves it.
    demo = create_gradio_interface()
    demo.launch()