kluvin's picture
Upload folder using huggingface_hub
3d366c1 verified
from flask import Flask, request, render_template
from transformers import pipeline
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
import polars as pl
import joblib
from pathlib import Path
import logging
import os
from time import perf_counter
from typing import Optional, Tuple
# Root logging configuration: INFO and above, with timestamp, logger name
# and level prepended to every message.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger and the Flask application object used by the routes.
logger = logging.getLogger(__name__)
app = Flask(__name__)
# Lookup from a class id (as a string) to its human-readable sentiment name.
# Ids are assigned positionally: "0" -> negative, "1" -> neutral,
# "2" -> positive.
CLASS_ID_TO_SENTIMENT = {
    str(class_id): sentiment
    for class_id, sentiment in enumerate(("negative", "neutral", "positive"))
}
def categorize_probability(probability: Optional[float]) -> Tuple[str, str, str]:
    """
    Translate a 0-1 probability into a qualitative badge.

    The probability is converted to a percentage and clamped to [0, 100].
    Returns a ``(label, css_class, display_value)`` tuple:

    * ``None``          -> ("Unknown", "probability-unknown", "N/A")
    * percent >= 80     -> "Definitely"
    * percent >= 60     -> "Probably"
    * anything lower    -> "Maybe"

    ``display_value`` is the clamped percentage rendered with no decimals.
    """
    if probability is None:
        return ("Unknown", "probability-unknown", "N/A")

    clamped_percent = max(0.0, min(probability * 100.0, 100.0))
    shown = f"{clamped_percent:.0f}%"

    # Tiers are checked from the highest floor downwards; first match wins.
    for floor, tier in ((80, "Definitely"), (60, "Probably")):
        if clamped_percent >= floor:
            return (tier, f"probability-{tier.lower()}", shown)

    return ("Maybe", "probability-maybe", shown)
# Example inputs offered in the UI; the first entry is the default preset
# that the home page classifies up front.
PRESET_TEXTS: list[str] = [
    "flower isn't beautiful",
    "there is no more love. only pain.",
    "one isn't a beauty, but two is a wondrous wonder",
    "hvl is a fake university #uibforever",
]
# The model cache lives under $HF_HOME when that variable is set (e.g. HF
# Spaces persistent storage), otherwise under the current working directory.
CACHE_DIR = Path(os.getenv("HF_HOME", ".")) / ".model_cache"
# parents=True: $HF_HOME may name a directory that has not been created yet,
# in which case a plain mkdir() would raise FileNotFoundError at import time.
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# The transformer model is loaded once at import time and reused by every
# request.
logger.info("Loading BERTweet from HuggingFace Hub...")
bertweet_pipeline = pipeline("sentiment-analysis", model="kluvin/bertweet-tweet-sentiment")
logger.info("BERTweet loaded successfully")
def _tfidf_pipeline(classifier, max_features: int = 2000) -> Pipeline:
    """Build an unfitted TF-IDF -> classifier pipeline with steps "tfidf" and "clf"."""
    vectorizer = TfidfVectorizer(max_features=max_features, stop_words="english")
    return Pipeline([("tfidf", vectorizer), ("clf", classifier)])


# Unfitted classical baselines, keyed by the display name shown in the UI.
# Every model shares the same TF-IDF front end; only the Random Forest uses a
# smaller vocabulary (500 features instead of 2000).
model_configs = {
    "Decision Tree": _tfidf_pipeline(
        DecisionTreeClassifier(max_depth=10, random_state=42)
    ),
    "Random Forest": _tfidf_pipeline(
        RandomForestClassifier(n_estimators=100, random_state=42),
        max_features=500,
    ),
    "Logistic Regression": _tfidf_pipeline(
        LogisticRegression(max_iter=1000, random_state=42)
    ),
    "Linear SVM": _tfidf_pipeline(LinearSVC(random_state=42)),
}
# Fitted scikit-learn pipelines keyed by display name. Filled either from the
# on-disk cache or by training the pipelines declared in model_configs.
sklearn_pipelines = {}
cache_file = CACHE_DIR / "ml_models.joblib"

if cache_file.exists():
    logger.info("Loading cached ML models...")
    try:
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. This is only acceptable while the cache directory is
        # writable exclusively by this application.
        cached_pipelines = joblib.load(cache_file)
    except Exception as e:
        # A corrupt or incompatible cache must not prevent startup.
        logger.error(f"Failed to load cache: {e}")
        logger.info("Will retrain models...")
    else:
        # Only trust the cache if it holds exactly the configured models;
        # otherwise a stale file would silently hide added/renamed models.
        if isinstance(cached_pipelines, dict) and set(cached_pipelines) == set(model_configs):
            sklearn_pipelines = cached_pipelines
            logger.info("✓ Cached models loaded successfully!")
        else:
            logger.warning("Cached models do not match the configured models")
            logger.info("Will retrain models...")

if not sklearn_pipelines:
    logger.info("Loading training data and training ML models...")
    splits = {'train': 'train.jsonl'}
    df = pl.read_ndjson('hf://datasets/SetFit/tweet_sentiment_extraction/' + splits['train'])
    X_train = df['text'].to_list()
    y_train = df['label'].to_list()

    logger.info("Training models...")
    for model_name, sklearn_pipeline in model_configs.items():
        logger.info(f" Training {model_name}...")
        sklearn_pipeline.fit(X_train, y_train)
        sklearn_pipelines[model_name] = sklearn_pipeline

    logger.info("Saving models to cache...")
    try:
        joblib.dump(sklearn_pipelines, cache_file)
    except OSError as e:
        # Caching is an optimisation: the freshly trained models are already
        # in memory, so an unwritable cache location should not abort startup.
        logger.warning(f"Could not write model cache to {cache_file}: {e}")
    else:
        logger.info(f"✓ Models cached at {cache_file}")

logger.info("All models loaded and ready!")
def render_model_result(model_name: str, sentiment_name: str, probability: float | None) -> str:
    """
    Render one model's prediction as an HTML card.

    The card carries the sentiment as a CSS class on the wrapper, the model
    name as heading, the capitalised sentiment, and a probability badge whose
    label, CSS modifier and value come from categorize_probability().
    Returns the markup as a string with a leading and trailing newline.
    """
    badge_label, badge_css, badge_value = categorize_probability(probability)
    sentiment_title = sentiment_name.capitalize()

    # The empty first and last entries produce the leading/trailing newline.
    card_lines = [
        '',
        f'<div class="model-result {sentiment_name}">',
        f'<h3>{model_name}</h3>',
        f'<p class="sentiment">{sentiment_title}</p>',
        '<p class="confidence">',
        f'<span class="probability-badge {badge_css}">',
        f'<span class="probability-label">{badge_label}</span>',
        f'<span class="probability-value">{badge_value}</span>',
        '</span>',
        '</p>',
        '</div>',
        '',
    ]
    return "\n".join(card_lines)
def _decision_scores_to_confidence(decision_scores) -> float:
    """
    Convert raw decision_function margins into a 0-1 confidence heuristic.

    For several per-class scores this is the softmax weight of the highest
    score; for a single (binary) margin it is the logistic sigmoid of its
    absolute value. Either way the result grows with the margin of the
    winning class. It is a display heuristic, not a calibrated probability.
    """
    import math  # local import keeps this block self-contained

    try:
        scores = [float(score) for score in decision_scores]
    except TypeError:
        # A binary classifier yields one scalar margin, which is not iterable.
        scores = [float(decision_scores)]

    if len(scores) == 1:
        return 1.0 / (1.0 + math.exp(-abs(scores[0])))

    # Subtracting the top score keeps exp() from overflowing; the winning
    # class contributes exp(0) == 1 to the denominator.
    top_score = max(scores)
    return 1.0 / sum(math.exp(score - top_score) for score in scores)


def build_results_markup(text_input: str) -> str:
    """
    Classify ``text_input`` with every loaded model and return the HTML.

    Runs the BERTweet pipeline first, then each fitted scikit-learn pipeline
    in ``sklearn_pipelines``. Each prediction becomes one card produced by
    render_model_result(). The cards are wrapped in a results grid preceded
    by an aside reporting the total inference time in milliseconds.

    Confidence per scikit-learn model: the largest ``predict_proba`` value
    when available, otherwise a margin-based heuristic derived from
    ``decision_function``, otherwise ``None`` (rendered as unknown).

    Raises whatever the underlying models raise, and KeyError if a predicted
    label is not a key of CLASS_ID_TO_SENTIMENT.
    """
    inference_start = perf_counter()
    result_cards = []

    # The transformer pipeline returns a list with one dict per input text.
    pipeline_output = bertweet_pipeline(text_input)[0]
    predicted_class_id = pipeline_output['label']
    probability = pipeline_output['score']
    sentiment_name = CLASS_ID_TO_SENTIMENT[predicted_class_id]
    logger.info(f"BERTweet prediction: {text_input} -> {sentiment_name} ({probability:.4f})")
    result_cards.append(render_model_result("BERTweet (Transformer)", sentiment_name, probability))

    # scikit-learn expects a batch; build the one-element batch once.
    inputs = [text_input]
    for model_name, sklearn_pipeline in sklearn_pipelines.items():
        predicted_class = sklearn_pipeline.predict(inputs)[0]
        classifier = sklearn_pipeline.named_steps['clf']

        if hasattr(classifier, 'predict_proba'):
            probability = sklearn_pipeline.predict_proba(inputs)[0].max()
        elif hasattr(classifier, 'decision_function'):
            # Margin-only models (e.g. LinearSVC): larger winning margin must
            # mean higher confidence, so derive it from the top score.
            decision_scores = sklearn_pipeline.decision_function(inputs)[0]
            probability = _decision_scores_to_confidence(decision_scores)
        else:
            probability = None

        sentiment_name = CLASS_ID_TO_SENTIMENT[str(predicted_class)]
        result_cards.append(render_model_result(model_name, sentiment_name, probability))

    results_html = "".join(result_cards)
    elapsed_ms = (perf_counter() - inference_start) * 1000
    return (
        f'<aside class="inference-meta">Inference time: {elapsed_ms:.0f} ms</aside>'
        f'<div class="results-grid">{results_html}</div>'
    )
@app.route('/')
def home():
    """
    Serve the landing page.

    The first preset text is classified up front so the page can show
    results immediately. If that classification fails the error is logged
    with its traceback and the page is rendered with empty initial results.
    """
    first_preset = PRESET_TEXTS[0]

    try:
        logger.info("Precomputing initial classification for default preset...")
        precomputed_markup = build_results_markup(first_preset)
    except Exception as e:
        # Best effort only: the page must still load without the preview.
        logger.exception(f"Failed to precompute initial results: {e}")
        precomputed_markup = ""

    template_context = {
        "presets": PRESET_TEXTS,
        "default_preset": first_preset,
        "initial_results": precomputed_markup,
    }
    return render_template('index.html', **template_context)
def _error_markup(message: str) -> str:
    """Wrap an already HTML-safe ``message`` in the error result fragment."""
    return (
        '\n<div class="result error">\n'
        f'<h2>Error: {message}</h2>\n'
        '</div>\n'
    )


@app.route('/classify', methods=['POST'])
def classify():
    """
    Classify the posted ``text`` form field and return an HTML fragment.

    Returns the results markup from build_results_markup() on success. A
    missing or whitespace-only ``text`` field yields an error fragment asking
    for input. Any exception raised during classification is logged with its
    traceback and reported as an error fragment; the handler itself does not
    raise.
    """
    from html import escape  # local import keeps this block self-contained

    try:
        # .get(): a request without the field is treated like empty input
        # instead of surfacing a raw "400 Bad Request" exception message.
        cleaned_text = request.form.get('text', '').strip()
        if not cleaned_text:
            return _error_markup("Please enter some text")

        logger.info(f"Classifying: {cleaned_text[:50]}...")
        return build_results_markup(cleaned_text)
    except Exception as e:
        logger.exception(f"Classification error: {e}")
        # Exception text may echo user input or contain markup characters;
        # escape it before embedding it in the HTML response.
        return _error_markup(escape(str(e)))
if __name__ == "__main__":
    # Direct-execution entry point (development server).
    # Debug mode stays on by default, as before, but can now be switched off
    # with FLASK_DEBUG=0/false/no/off. The flag is resolved *before* it is
    # used: previously app.debug was tested before app.run(debug=True) set it,
    # so the DEBUG log level was never applied.
    debug_enabled = os.getenv("FLASK_DEBUG", "1").strip().lower() not in {"0", "false", "no", "off"}
    if debug_enabled:
        logger.setLevel(logging.DEBUG)
    app.run(debug=debug_enabled)