# CoinPulse / app.py
# Author: houmanrajabi — commit a0247f0 (verified): "Create app.py"
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi
import os
from datetime import datetime
# Flag storage: with HF_TOKEN set in the environment, flagged examples are
# pushed to this private Hub dataset; without it (or if the push fails) they
# are written to a local CSV file instead.
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_REPO = "houmanrajabi/coinpulse-flagged-data"

# Fine-tuned CoinPulse checkpoint on the Hugging Face Hub. The tokenizer is
# loaded from the same repo because it carries the extended crypto vocabulary.
MODEL_PATH = "houmanrajabi/CoinPulse"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
model.eval()  # inference only: switch off training-time behavior such as dropout

# Class index -> label name. This mapping is hard-coded here, NOT read from
# model.config; NOTE(review): keep it in sync with the checkpoint's id2label.
id2label = dict(enumerate(("negative", "neutral", "positive")))
def predict_sentiment(text, temperature=2.0):
    """Classify the sentiment of ``text`` with the CoinPulse model.

    Args:
        text: Input text; it is truncated to the first 512 tokens.
        temperature: Softmax temperature applied to the logits. Values > 1
            flatten the distribution and values < 1 sharpen it. The arg-max
            class is unchanged. Must be a positive number.

    Returns:
        dict mapping every label in ``id2label`` to its probability. This is
        the format ``gr.Label`` expects.

    Raises:
        gr.Error: if ``text`` is empty/whitespace or ``temperature`` is not a
            positive number. Gradio shows the message to the user. Returning a
            dict with a string "confidence" instead cannot be rendered by
            ``gr.Label``.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text")
    # gr.Number yields None when the field is cleared. A value of 0 would
    # divide by zero, and a negative value would invert the class ranking.
    if temperature is None or temperature <= 0:
        raise gr.Error("Temperature must be a positive number")

    # Tokenize with the model's extended (crypto-aware) vocabulary. A single
    # sequence needs no padding.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():
        logits = model(**inputs).logits
    # The batch holds one sequence: take its row of class probabilities.
    probs = torch.softmax(logits / temperature, dim=-1)[0]

    return {label: float(probs[idx]) for idx, label in id2label.items()}
# Column order of a flagged example, shared by the Hub dataset and the CSV.
_FLAG_COLUMNS = (
    "text", "temperature", "prediction",
    "sentiment_label", "issue_type", "flag_combined", "timestamp",
)


def _build_flag_record(text_input, temperature_input, prediction_output,
                       sentiment_label, issue_type):
    """Return one flagged example as a dict keyed by ``_FLAG_COLUMNS``."""
    flags = [str(flag) for flag in (sentiment_label, issue_type) if flag]
    return {
        "text": "" if text_input is None else str(text_input),
        # gr.Number yields None when the field is empty; store that as null.
        "temperature": None if temperature_input is None else float(temperature_input),
        "prediction": str(prediction_output),
        "sentiment_label": str(sentiment_label) if sentiment_label else "",
        "issue_type": str(issue_type) if issue_type else "",
        "flag_combined": " + ".join(flags) if flags else "flagged",
        # One timestamp per flag, whichever sink ends up storing it.
        "timestamp": datetime.now().isoformat(),
    }


def _append_flag_to_hub(record):
    """Append ``record`` to the private Hub dataset ``DATASET_REPO``.

    The whole ``train`` split is downloaded, extended by one row and pushed
    back. Any failure propagates so the caller can fall back to CSV.
    """
    if HfApi(token=HF_TOKEN).repo_exists(DATASET_REPO, repo_type="dataset"):
        # The repo exists, so its rows MUST load. Starting from an empty dict
        # after e.g. a transient download error would make push_to_hub
        # overwrite every previously stored flag. Re-download so the append
        # starts from the latest pushed revision, not a possibly cached copy.
        data = load_dataset(
            DATASET_REPO,
            split="train",
            token=HF_TOKEN,
            download_mode="force_redownload",
        ).to_dict()
    else:
        print(f"Dataset {DATASET_REPO} not found; creating it")
        data = {column: [] for column in _FLAG_COLUMNS}

    for column in _FLAG_COLUMNS:
        data[column].append(record[column])

    Dataset.from_dict(data).push_to_hub(DATASET_REPO, token=HF_TOKEN, private=True)


def _append_flag_to_csv(record, csv_path="flagged_data/flags.csv"):
    """Append ``record`` to a local CSV file, writing the header on first use."""
    import csv

    directory = os.path.dirname(csv_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    write_header = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=_FLAG_COLUMNS)
        if write_header:
            writer.writeheader()
        writer.writerow(record)


def save_flagged_data(text_input, temperature_input, prediction_output,
                      sentiment_label, issue_type):
    """Persist a user-flagged prediction and return a status message.

    The flag is appended to the Hub dataset ``DATASET_REPO`` when ``HF_TOKEN``
    is set. Without a token, or if the Hub write fails, it is appended to
    ``flagged_data/flags.csv`` instead.

    Args:
        text_input: Contents of the input text box.
        temperature_input: Value of the Temperature field (may be None).
        prediction_output: Current value of the prediction ``gr.Label``.
        sentiment_label: Corrected sentiment chosen by the user, or None.
        issue_type: Issue category chosen by the user, or None.

    Returns:
        A human-readable status string. Errors are reported in the string,
        not raised.
    """
    if text_input is None or not str(text_input).strip():
        return "Nothing to flag: please enter some text first."
    try:
        record = _build_flag_record(text_input, temperature_input, prediction_output,
                                    sentiment_label, issue_type)
        flag_combined = record["flag_combined"]
        if HF_TOKEN:
            try:
                _append_flag_to_hub(record)
                return f"Successfully flagged as: {flag_combined}"
            except Exception as e:
                # Best effort: log the Hub failure and keep the flag locally.
                print(f"HF Dataset error: {e}")
        _append_flag_to_csv(record)
        return f"Flagged as: {flag_combined} (saved to CSV)"
    except Exception as e:
        return f"Error saving flag: {e}"
# Sample inputs for gr.Examples. Each example is a one-element row holding the
# value for the single text input. The last two are general-finance sentences
# rather than crypto headlines.
_EXAMPLE_TEXTS = (
    "Bitcoin reaches new all-time high amid institutional adoption",
    "Major cryptocurrency exchange faces security breach, users advised to withdraw funds",
    "Ethereum network processes steady transaction volume with no significant changes",
    "Solana announces new partnership with leading DeFi protocol",
    "NFT market sees declining volumes as collectors wait for next bull run",
    "Cardano's latest upgrade brings enhanced smart contract capabilities",
    "Dogecoin community rallies behind charitable initiatives",
    "PancakeSwap introduces new liquidity mining incentives",
    "Stock prices plummeted after the disappointing earnings report.",
    "The quarterly results were in line with market forecasts.",
)
examples = [[sample] for sample in _EXAMPLE_TEXTS]
# Build the UI with gr.Blocks (rather than gr.Interface) so the custom
# sentiment/issue flagging controls can be laid out below the prediction.
with gr.Blocks(theme=gr.themes.Soft(), title="CoinPulse Sentiment Analysis") as demo:
    # Header. Markdown needs blank lines around lists; otherwise a line such
    # as "**Performance:**" is folded into the preceding bullet.
    gr.Markdown("""
    # 🪙 CoinPulse: Cryptocurrency Sentiment Analysis

    This model analyzes sentiment in cryptocurrency-related text using a fine-tuned FinBERT model
    with an **extended tokenizer vocabulary**.

    **Key Features:**

    - **Extended Vocabulary**: 520 crypto-specific tokens added (e.g., bitcoin, ethereum, defi, nft, solana)
    - **Total Vocabulary Size**: 31,024 tokens
    - **Base Model**: ProsusAI/finbert
    - **Fine-tuned**: On cryptocurrency news and social media data
    - **Classification**: Positive, Negative, Neutral

    **Performance:**

    - Test Accuracy: 93.00%
    - Test F1 (weighted): 92.53%
    - Best Validation F1: 92.86%
    """)

    # Inputs on the left (wider column), prediction on the right.
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                lines=5,
                placeholder="Enter cryptocurrency-related text here...",
                label="Input Text",
            )
            temperature_input = gr.Number(
                label="Temperature",
                value=2.0,
                precision=1,
                step=0.1,
                minimum=0.1,
                maximum=10.0,
            )
            predict_btn = gr.Button("Analyze Sentiment", variant="primary")
        with gr.Column(scale=1):
            prediction_output = gr.Label(num_top_classes=3, label="Sentiment Prediction")

    # Flagging section: both radios start unselected and are optional.
    gr.Markdown("---")
    gr.Markdown("""
    ### 📌 Help Improve This Model

    If the prediction is incorrect or interesting, you can flag it:

    - **Correct Sentiment**: What should the sentiment be?
    - **Issue Type**: What's wrong or interesting about this prediction?

    You can select one, both, or neither.
    Don't forget to submit your selection by clicking **🚩 Flag This Prediction**.
    """)
    with gr.Row():
        sentiment_flag = gr.Radio(
            choices=["positive", "negative", "neutral"],
            label="Correct Sentiment",
            value=None,
        )
        issue_flag = gr.Radio(
            choices=["incorrect prediction", "offensive content", "interesting case", "edge case"],
            label="Issue Type",
            value=None,
        )
    with gr.Row():
        flag_btn = gr.Button("🚩 Flag This Prediction", variant="secondary")
        flag_status = gr.Textbox(label="Flag Status", interactive=False)

    gr.Markdown("---")
    gr.Markdown("### 💡 Try These Examples:")
    gr.Examples(
        examples=examples,
        inputs=[text_input],
        outputs=prediction_output,
        fn=predict_sentiment,
        cache_examples=False,
    )

    # Run the classifier. The temperature must be passed too; wiring only the
    # text made the Temperature field have no effect on the prediction.
    predict_btn.click(
        fn=predict_sentiment,
        inputs=[text_input, temperature_input],
        outputs=prediction_output,
    )
    # Persist the current input, the prediction and the user's flag choices.
    flag_btn.click(
        fn=save_flagged_data,
        inputs=[text_input, temperature_input, prediction_output, sentiment_flag, issue_flag],
        outputs=flag_status,
    )
if __name__ == "__main__":
    print("Starting CoinPulse Sentiment Analysis...")
    # Tell the operator where flags will be stored: the Hub dataset when a
    # token is configured, otherwise the local flagged_data/ directory.
    storage_note = (
        f"Flagged data will be saved to: {DATASET_REPO}"
        if HF_TOKEN
        else "Flagged data will be saved locally to: flagged_data/"
    )
    print(storage_note)
    demo.launch()