Spaces:
Sleeping
Sleeping
import csv
import os
from datetime import datetime

import gradio as gr
import torch
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load model and tokenizer once at import time so every request reuses them.
MODEL_PATH = "houmanrajabi/CoinPulse"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

# Set model to evaluation mode (inference only; no training happens in this app).
model.eval()

# Label mapping from config
# NOTE(review): hard-coded here; presumably mirrors the checkpoint's
# config.id2label -- confirm the index order against the model card.
id2label = {0: 'negative', 1: 'neutral', 2: 'positive'}

# HF Dataset configuration for flagging.
# HF_TOKEN is read from the environment; when it is unset or empty,
# save_flagged_data skips the Hub and writes flags to a local CSV instead.
HF_TOKEN = os.getenv("HF_TOKEN")
DATASET_REPO = "houmanrajabi/coinpulse-flagged-data"
def predict_sentiment(text, temperature=2.0):
    """Classify the sentiment of ``text`` and return per-label probabilities.

    Args:
        text: Raw input text. ``None``, empty and whitespace-only input is
            rejected.
        temperature: Positive number the logits are divided by before the
            softmax. Values above 1 flatten the distribution, values below 1
            sharpen it.

    Returns:
        A dict mapping every label in ``id2label`` to its probability as a
        float, or ``{"error": <message>}`` when an argument is rejected.
    """
    # Guard against None as well as blank strings before calling .strip().
    if not text or not text.strip():
        return {"error": "Please enter some text"}

    # A missing, zero or negative temperature would raise or yield NaN /
    # inverted probabilities, so reject it before doing any model work.
    if temperature is None or not temperature > 0:
        return {"error": "Temperature must be a positive number"}

    # Tokenize input with extended vocabulary
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**encoded).logits / temperature
        # A single text was passed in, so row 0 holds its class probabilities.
        probs = torch.softmax(logits, dim=1)[0]

    return {label: float(probs[index]) for index, label in id2label.items()}
# Column order shared by the Hub dataset and the CSV fallback.
_FLAG_COLUMNS = (
    "text",
    "temperature",
    "prediction",
    "sentiment_label",
    "issue_type",
    "flag_combined",
    "timestamp",
)


def _combine_flags(sentiment_label, issue_type):
    """Join the selected flag values with " + ", or return "flagged" if none."""
    chosen = [flag for flag in (sentiment_label, issue_type) if flag]
    return " + ".join(chosen) if chosen else "flagged"


def _push_flag_to_hub(row):
    """Append ``row`` (column -> value) to the Hub dataset and push it back."""
    try:
        dataset = load_dataset(DATASET_REPO, split="train", token=HF_TOKEN)
        data_dict = dataset.to_dict()
    except FileNotFoundError as e:
        # Start an empty table ONLY when the dataset is missing. Any other
        # failure (network, auth, ...) must propagate: pushing a fresh table
        # after a transient load error would overwrite every earlier flag.
        # NOTE(review): assumes `datasets` reports a missing/empty repo with a
        # FileNotFoundError subclass -- confirm for the installed version.
        print(f"Creating new dataset error: {e}")
        data_dict = {column: [] for column in _FLAG_COLUMNS}

    for column, value in row.items():
        data_dict[column].append(value)

    # NOTE(review): load -> append -> push is not atomic; two simultaneous
    # flags could presumably overwrite one another -- confirm if that matters.
    Dataset.from_dict(data_dict).push_to_hub(
        DATASET_REPO,
        token=HF_TOKEN,
        private=True,
    )


def _append_flag_to_csv(row):
    """Append ``row`` to flagged_data/flags.csv, writing the header if new."""
    os.makedirs("flagged_data", exist_ok=True)
    csv_path = "flagged_data/flags.csv"
    write_header = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(_FLAG_COLUMNS)
        writer.writerow(row)


def save_flagged_data(text_input, temperature_input, prediction_output,
                      sentiment_label, issue_type):
    """Save flagged data to HF Dataset or CSV.

    When ``HF_TOKEN`` is set, the record is appended to ``DATASET_REPO`` on
    the Hub. If that fails for any reason, or no token is configured, the
    record is appended to ``flagged_data/flags.csv`` instead.

    Args:
        text_input: The text that was analysed.
        temperature_input: The temperature value shown in the UI.
        prediction_output: The prediction being flagged; stored as ``str()``
            of whatever is passed in on the Hub path.
        sentiment_label: The sentiment the user considers correct, or a falsy
            value when none was selected.
        issue_type: The issue category the user picked, or a falsy value when
            none was selected.

    Returns:
        A human-readable status string. This function does not raise; any
        unexpected error is reported as ``"Error saving flag: ..."``.
    """
    try:
        flag_combined = _combine_flags(sentiment_label, issue_type)
        timestamp = datetime.now().isoformat()

        if HF_TOKEN:
            try:
                _push_flag_to_hub({
                    "text": str(text_input),
                    "temperature": float(temperature_input),
                    "prediction": str(prediction_output),
                    "sentiment_label": str(sentiment_label) if sentiment_label else "",
                    "issue_type": str(issue_type) if issue_type else "",
                    "flag_combined": flag_combined,
                    "timestamp": timestamp,
                })
                return f"Successfully flagged as: {flag_combined}"
            except Exception as e:
                # Deliberate best effort: log and fall through to CSV fallback.
                print(f"HF Dataset error: {e}")

        # CSV Fallback
        _append_flag_to_csv([
            text_input,
            temperature_input,
            prediction_output,
            sentiment_label if sentiment_label else "",
            issue_type if issue_type else "",
            flag_combined,
            timestamp,
        ])
        return f"Flagged as: {flag_combined} (saved to CSV)"
    except Exception as e:
        # UI boundary: surface the failure in the status box instead of raising.
        return f"Error saving flag: {str(e)}"
# Example texts offered in the UI. Each row is a one-element list because the
# examples widget is bound to a single input component (the text box).
examples = [
    ["Bitcoin reaches new all-time high amid institutional adoption"],
    ["Major cryptocurrency exchange faces security breach, users advised to withdraw funds"],
    ["Ethereum network processes steady transaction volume with no significant changes"],
    ["Solana announces new partnership with leading DeFi protocol"],
    ["NFT market sees declining volumes as collectors wait for next bull run"],
    ["Cardano's latest upgrade brings enhanced smart contract capabilities"],
    ["Dogecoin community rallies behind charitable initiatives"],
    ["PancakeSwap introduces new liquidity mining incentives"],
    ["Stock prices plummeted after the disappointing earnings report."],
    ["The quarterly results were in line with market forecasts."],
]
# Create custom Gradio interface with Blocks for multi-select flagging.
# NOTE(review): the original indentation was lost, so the layout nesting below
# is reconstructed from statement order -- confirm against the rendered page.
with gr.Blocks(theme=gr.themes.Soft(), title="CoinPulse Sentiment Analysis") as demo:
    gr.Markdown("""
    # 🪙 CoinPulse: Cryptocurrency Sentiment Analysis

    This model analyzes sentiment in cryptocurrency-related text using a fine-tuned FinBERT model
    with an **extended tokenizer vocabulary**.

    **Key Features:**
    - **Extended Vocabulary**: 520 crypto-specific tokens added (e.g., bitcoin, ethereum, defi, nft, solana)
    - **Total Vocabulary Size**: 31,024 tokens
    - **Base Model**: ProsusAI/finbert
    - **Fine-tuned**: On cryptocurrency news and social media data
    - **Classification**: Positive, Negative, Neutral

    **Performance:**
    - Test Accuracy: 93.00%
    - Test F1 (weighted): 92.53%
    - Best Validation F1: 92.86%
    """)

    # Prediction area: inputs on the left, result on the right.
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                lines=5,
                placeholder="Enter cryptocurrency-related text here...",
                label="Input Text"
            )
            temperature_input = gr.Number(
                label="Temperature",
                value=2.0,
                precision=1,
                step=0.1,
                minimum=0.1,
                maximum=10.0
            )
            predict_btn = gr.Button("Analyze Sentiment", variant="primary")
        with gr.Column(scale=1):
            prediction_output = gr.Label(num_top_classes=3, label="Sentiment Prediction")

    gr.Markdown("---")

    # Flagging area.
    gr.Markdown("""
    ### Help Improve This Model

    If the prediction is incorrect or interesting, you can flag it:

    - **Correct Sentiment**: What should the sentiment be?
    - **Issue Type**: What's wrong or interesting about this prediction?

    You can select one, both, or neither.
    Don't forget to submit your selection by clicking 🚩 Flag This Prediction
    """)
    with gr.Row():
        sentiment_flag = gr.Radio(
            choices=["positive", "negative", "neutral"],
            label="Correct Sentiment",
            value=None
        )
        issue_flag = gr.Radio(
            choices=["incorrect prediction", "offensive content", "interesting case", "edge case"],
            label="Issue Type",
            value=None
        )
    with gr.Row():
        flag_btn = gr.Button("🚩 Flag This Prediction", variant="secondary")
        flag_status = gr.Textbox(label="Flag Status", interactive=False)

    gr.Markdown("---")
    gr.Markdown("### 💡 Try These Examples:")
    gr.Examples(
        examples=examples,
        inputs=[text_input],
        outputs=prediction_output,
        fn=predict_sentiment,
        cache_examples=False
    )

    # Connect the predict button. The temperature control is passed along with
    # the text so the value shown in the UI (and later stored with a flag) is
    # the one actually used for the prediction.
    predict_btn.click(
        fn=predict_sentiment,
        inputs=[text_input, temperature_input],
        outputs=prediction_output
    )

    # Connect the flag button
    flag_btn.click(
        fn=save_flagged_data,
        inputs=[text_input, temperature_input, prediction_output, sentiment_flag, issue_flag],
        outputs=flag_status
    )
# Script entry point: report where flagged data will be stored, then start the app.
if __name__ == "__main__":
    print("Starting CoinPulse Sentiment Analysis...")
    if HF_TOKEN:
        print(f"Flagged data will be saved to: {DATASET_REPO}")
    else:
        print("Flagged data will be saved locally to: flagged_data/")
    # NOTE(review): original indentation was lost; launch is assumed to belong
    # inside the __main__ guard -- confirm.
    demo.launch()