# NOTE: the original paste began with Hugging Face Spaces page chrome
# ("Spaces: Sleeping Sleeping") — UI residue, not part of the program.
# Standard library
import os
import shutil
import warnings

# Third party
import gradio as gr
import transformers
from huggingface_hub import snapshot_download
from pyabsa import AspectTermExtraction as ATEPC

# 1. Compatibility setup for Hugging Face Spaces
warnings.filterwarnings("ignore")
# NOTE(review): these assignments set *class-level* attributes on
# PretrainedConfig — presumably a shim so pyabsa checkpoints load under a
# newer transformers release; confirm against the pinned transformers version.
transformers.PretrainedConfig.is_decoder = False
transformers.PretrainedConfig.output_attentions = False
transformers.PretrainedConfig.output_hidden_states = False

print(f"Transformers version: {transformers.__version__}")
def find_checkpoint_path(root, marker="fast_lcf_atepc"):
    """Recursively search *root* for a pyabsa checkpoint directory.

    A checkpoint directory is identified by containing a ``.config`` file
    whose file name includes *marker*.

    Args:
        root: Directory tree to walk.
        marker: Substring that must appear in the ``.config`` file name
            (defaults to ``"fast_lcf_atepc"``, the original hard-coded value).

    Returns:
        The path of the first directory that contains such a file, or
        ``None`` if no match is found anywhere under *root*.
    """
    for dirpath, _dirnames, filenames in os.walk(root):
        if any(name.endswith('.config') and marker in name for name in filenames):
            return dirpath
    return None
# 2. Path to the model on Hugging Face Hub
REPO_ID = "siwarroth/deberta-absa"
LOCAL_MODEL_DIR = "model_files"

print(f"Downloading model from Hugging Face Hub: {REPO_ID}...")
try:
    # Start from a clean local directory so stale files from a previous
    # run cannot mix with the fresh snapshot.
    if os.path.exists(LOCAL_MODEL_DIR):
        shutil.rmtree(LOCAL_MODEL_DIR)
    os.makedirs(LOCAL_MODEL_DIR)

    local_repo_path = snapshot_download(repo_id=REPO_ID)
    print(f"Full snapshot downloaded to: {local_repo_path}")

    source_checkpoint = find_checkpoint_path(local_repo_path)
    if source_checkpoint:
        print(f"Found valid checkpoint at: {source_checkpoint}. Copying to {LOCAL_MODEL_DIR}...")
        # Copy the checkpoint contents in a single call instead of the
        # original per-item loop; dirs_exist_ok (Python 3.8+) merges into
        # the freshly created, empty target directory.
        shutil.copytree(source_checkpoint, LOCAL_MODEL_DIR, dirs_exist_ok=True)
        CHECKPOINT_PATH = LOCAL_MODEL_DIR
    else:
        print("Warning: Could not find a valid .config file in the snapshot. Falling back to local 'model'.")
        CHECKPOINT_PATH = "model"
except Exception as e:
    # Best-effort startup: any failure (network, permissions, missing repo)
    # falls back to a locally bundled 'model' directory instead of crashing.
    print(f"Error during download or copy: {e}")
    CHECKPOINT_PATH = "model"
# Load the model once at startup (module import time) so every Gradio
# request reuses the same extractor instance.
print(f"Initializing AspectExtractor with checkpoint: {CHECKPOINT_PATH}")
# List files in initialization path for verification
if os.path.exists(CHECKPOINT_PATH):
    print(f"Files in {CHECKPOINT_PATH}: {os.listdir(CHECKPOINT_PATH)}")
# Global used by predict_absa below; raises at import if the checkpoint
# directory is unusable.
model = ATEPC.AspectExtractor(checkpoint=CHECKPOINT_PATH)
def predict_absa(text):
    """Run aspect-based sentiment analysis on *text*.

    Args:
        text: Raw user input from the Gradio textbox.

    Returns:
        A list of ``{"Aspect": ..., "Sentiment": ...}`` dicts, or a plain
        message string when the input is blank or yields no aspects
        (``gr.JSON`` renders either shape).
    """
    if not text.strip():
        return "Please enter some text to analyze."
    # print_result=False keeps pyabsa from writing per-request output to stdout.
    result = model.predict(text, print_result=False)
    if not result['aspect']:
        return "No aspects found in the input text."
    # Pair each extracted aspect term with its predicted polarity.
    return [
        {"Aspect": aspect, "Sentiment": sentiment}
        for aspect, sentiment in zip(result['aspect'], result['sentiment'])
    ]
# 3. Create Gradio Interface
demo = gr.Interface(
    fn=predict_absa,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Enter a sentence here (e.g., 'The coffee was great but the price was too high.')",
        label="Input Text"
    ),
    # JSON output accepts both the list-of-dicts result and the plain
    # message strings predict_absa returns for blank/no-aspect input.
    outputs=gr.JSON(label="ABSA Results"),
    title="DeBERTa-v3 Aspect Based Sentiment Analysis",
    description="This demo uses a fine-tuned DeBERTa-v3 model to extract aspects and classify their sentiment polarities.",
    examples=[
        ["The food was delicious but the service was extremely slow."],
        ["The battery life of this laptop is amazing, though the screen is a bit dim."],
        ["I love the interface, but the mobile app crashes frequently."]
    ],
    # Do not pre-run the examples through the model at build time.
    cache_examples=False
)

# Standard script guard: start the web server only when executed directly.
if __name__ == "__main__":
    demo.launch()