File size: 8,768 Bytes
a0247f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi
import os
from datetime import datetime

# Hub repository that provides both the tokenizer and the classifier weights.
MODEL_PATH = "houmanrajabi/CoinPulse"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# The app only runs inference, so the classifier is switched to evaluation
# mode right after loading (`eval()` returns the module itself).
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).eval()

# Class index -> sentiment name, in the order of the model's output logits.
# NOTE(review): hard-coded here; presumably mirrors the checkpoint's own
# config mapping - confirm against the model config.
id2label = dict(enumerate(("negative", "neutral", "positive")))

# Where flagged predictions go: the Hub dataset below when a token is present
# in the environment, otherwise a local CSV file (see save_flagged_data).
DATASET_REPO = "houmanrajabi/coinpulse-flagged-data"
HF_TOKEN = os.environ.get("HF_TOKEN")

def predict_sentiment(text, temperature=2.0):
    """Return class probabilities for *text* from the fine-tuned classifier.

    The model's logits are divided by ``temperature`` before the softmax.

    Args:
        text: Text to classify. ``None``, empty and whitespace-only input is
            rejected.
        temperature: Positive number the logits are divided by. Defaults
            to 2.0.

    Returns:
        dict: ``{label: probability}`` with one float per entry of
        ``id2label``, or ``{"error": message}`` when an argument is rejected.
    """
    # None is handled like empty text instead of raising AttributeError
    # on `.strip()`.
    if text is None or not text.strip():
        return {"error": "Please enter some text"}

    # Validate the divisor before running the model: None or a non-numeric
    # value cannot divide the logits, zero turns them into inf/NaN, and a
    # negative value would invert the ranking of the classes.
    temperature_error = {"error": "Temperature must be a positive number"}
    try:
        temperature = float(temperature)
    except (TypeError, ValueError):
        return temperature_error
    if not temperature > 0:  # written this way so NaN is rejected as well
        return temperature_error

    # Tokenize the single input, truncating anything beyond 512 tokens.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits / temperature
        # One text was passed in, so row 0 holds its class distribution.
        probs = torch.softmax(logits, dim=1)[0]

    # Drive the result from the label mapping instead of spelling out each
    # index by hand.
    return {label: float(probs[index]) for index, label in id2label.items()}

# Column order shared by the Hub dataset rows and the CSV fallback file.
_FLAG_COLUMNS = (
    "text", "temperature", "prediction",
    "sentiment_label", "issue_type", "flag_combined", "timestamp",
)


def _combine_flags(sentiment_label, issue_type):
    """Join the selected flag values with " + "; "flagged" when none is set."""
    selected = [flag for flag in (sentiment_label, issue_type) if flag]
    return " + ".join(selected) if selected else "flagged"


def _push_flag_to_hub(row):
    """Append *row* (a column -> value dict) to the Hub dataset and push it."""
    try:
        dataset = load_dataset(DATASET_REPO, split="train", token=HF_TOKEN)
    except FileNotFoundError as e:
        # Only a missing (or still empty) repository means "start a new
        # dataset". Any other failure - network, auth, schema - propagates to
        # the caller so the existing dataset is never overwritten by a table
        # that contains just this one row.
        # NOTE(review): relies on `datasets` signalling a missing/empty repo
        # with FileNotFoundError subclasses - confirm for the pinned version.
        print(f"Creating new dataset error: {e}")
        data_dict = {column: [] for column in _FLAG_COLUMNS}
    else:
        data_dict = dataset.to_dict()

    for column in _FLAG_COLUMNS:
        data_dict[column].append(row[column])

    Dataset.from_dict(data_dict).push_to_hub(
        DATASET_REPO,
        token=HF_TOKEN,
        private=True
    )


def _append_flag_to_csv(values):
    """Append one row to flagged_data/flags.csv, writing the header first if new."""
    # Local import: csv is only needed on this fallback path.
    import csv

    os.makedirs("flagged_data", exist_ok=True)
    csv_path = "flagged_data/flags.csv"
    file_exists = os.path.exists(csv_path)

    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if not file_exists:
            writer.writerow(_FLAG_COLUMNS)
        writer.writerow(values)


def save_flagged_data(text_input, temperature_input, prediction_output, 
                      sentiment_label, issue_type):
    """Save a flagged prediction to the Hub dataset, or to a local CSV.

    When ``HF_TOKEN`` is set, the row is appended to ``DATASET_REPO`` and the
    dataset is pushed. If no token is set, or the Hub path fails for any
    reason, the row is appended to ``flagged_data/flags.csv`` instead.

    Args:
        text_input: The text that was analyzed.
        temperature_input: The temperature value shown in the UI.
        prediction_output: The prediction shown to the user (stored as text).
        sentiment_label: Sentiment the user considers correct, or None.
        issue_type: Kind of issue the user selected, or None.

    Returns:
        str: A status message. This function never raises; any failure is
        reported as ``"Error saving flag: ..."``.
    """
    try:
        flag_combined = _combine_flags(sentiment_label, issue_type)
        # One timestamp per flag, shared by the Hub attempt and the fallback.
        timestamp = datetime.now().isoformat()

        if HF_TOKEN:
            try:
                _push_flag_to_hub({
                    "text": str(text_input),
                    "temperature": float(temperature_input),
                    "prediction": str(prediction_output),
                    "sentiment_label": str(sentiment_label) if sentiment_label else "",
                    "issue_type": str(issue_type) if issue_type else "",
                    "flag_combined": flag_combined,
                    "timestamp": timestamp,
                })
                return f"Successfully flagged as: {flag_combined}"
            except Exception as e:
                # Best effort: report the problem and keep the flag locally
                # rather than losing it.
                print(f"HF Dataset error: {e}")

        # CSV fallback (no token configured, or the Hub path failed).
        _append_flag_to_csv([
            text_input, temperature_input, prediction_output,
            sentiment_label if sentiment_label else "",
            issue_type if issue_type else "",
            flag_combined,
            timestamp,
        ])
        return f"Flagged as: {flag_combined} (saved to CSV)"

    except Exception as e:
        return f"Error saving flag: {str(e)}"

# Sample inputs offered in the UI. Every example is a one-element row that
# holds just the text to analyze.
examples = [
    [headline]
    for headline in (
        "Bitcoin reaches new all-time high amid institutional adoption",
        "Major cryptocurrency exchange faces security breach, users advised to withdraw funds",
        "Ethereum network processes steady transaction volume with no significant changes",
        "Solana announces new partnership with leading DeFi protocol",
        "NFT market sees declining volumes as collectors wait for next bull run",
        "Cardano's latest upgrade brings enhanced smart contract capabilities",
        "Dogecoin community rallies behind charitable initiatives",
        "PancakeSwap introduces new liquidity mining incentives",
        "Stock prices plummeted after the disappointing earnings report.",
        "The quarterly results were in line with market forecasts.",
    )
]

# Build the UI with Blocks (rather than a plain Interface) so that the two
# independent flag selectors and a custom flag button can sit next to the
# prediction widgets.
with gr.Blocks(theme=gr.themes.Soft(), title="CoinPulse Sentiment Analysis") as demo:
    
    gr.Markdown("""
    # 🪙 CoinPulse: Cryptocurrency Sentiment Analysis
    
    This model analyzes sentiment in cryptocurrency-related text using a fine-tuned FinBERT model
    with an **extended tokenizer vocabulary**.

    **Key Features:**
    - **Extended Vocabulary**: 520 crypto-specific tokens added (e.g., bitcoin, ethereum, defi, nft, solana)
    - **Total Vocabulary Size**: 31,024 tokens
    - **Base Model**: ProsusAI/finbert
    - **Fine-tuned**: On cryptocurrency news and social media data
    - **Classification**: Positive, Negative, Neutral

    **Performance:**
    - Test Accuracy: 93.00%
    - Test F1 (weighted): 92.53%
    - Best Validation F1: 92.86%
    """)
    
    # Inputs on the left (wider column), prediction on the right.
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                lines=5,
                placeholder="Enter cryptocurrency-related text here...",
                label="Input Text"
            )
            temperature_input = gr.Number(
                label="Temperature",
                value=2.0,
                precision=1,
                step=0.1,
                minimum=0.1,
                maximum=10.0
            )
            predict_btn = gr.Button("Analyze Sentiment", variant="primary")
        
        with gr.Column(scale=1):
            prediction_output = gr.Label(num_top_classes=3, label="Sentiment Prediction")
    
    gr.Markdown("---")
    gr.Markdown("""
    ### 📌 Help Improve This Model
    
    If the prediction is incorrect or interesting, you can flag it:
    - **Correct Sentiment**: What should the sentiment be?
    - **Issue Type**: What's wrong or interesting about this prediction?
    
    You can select one, both, or neither. 
    
    Don't forget to submit your selection by clicking 🚩 Flag This Prediction
    """)
    
    # Both selectors start unselected; either one may be left empty.
    with gr.Row():
        sentiment_flag = gr.Radio(
            choices=["positive", "negative", "neutral"],
            label="Correct Sentiment",
            value=None
        )
        issue_flag = gr.Radio(
            choices=["incorrect prediction", "offensive content", "interesting case", "edge case"],
            label="Issue Type",
            value=None
        )
    
    with gr.Row():
        flag_btn = gr.Button("🚩 Flag This Prediction", variant="secondary")
        flag_status = gr.Textbox(label="Flag Status", interactive=False)
    
    gr.Markdown("---")
    gr.Markdown("### 💡 Try These Examples:")
    
    # Each example row carries only the text, so only the text box is listed
    # as an input here.
    gr.Examples(
        examples=examples,
        inputs=[text_input],
        outputs=prediction_output,
        fn=predict_sentiment,
        cache_examples=False
    )
    
    # Connect the predict button. The temperature field is passed along with
    # the text so the value shown in the UI is the one actually used for the
    # prediction (and matches what gets recorded when the result is flagged).
    predict_btn.click(
        fn=predict_sentiment,
        inputs=[text_input, temperature_input],
        outputs=prediction_output
    )
    
    # Connect the flag button: the current inputs, the displayed prediction
    # and both selectors are stored; the status message is shown to the user.
    flag_btn.click(
        fn=save_flagged_data,
        inputs=[text_input, temperature_input, prediction_output, sentiment_flag, issue_flag],
        outputs=flag_status
    )

if __name__ == "__main__":
    # Script entry point: announce where flagged rows will be stored (the Hub
    # dataset when a token is configured, a local folder otherwise), then
    # start the Gradio server.
    print("Starting CoinPulse Sentiment Analysis...")
    storage_notice = (
        f"Flagged data will be saved to: {DATASET_REPO}"
        if HF_TOKEN
        else "Flagged data will be saved locally to: flagged_data/"
    )
    print(storage_notice)

    demo.launch()