Spaces:
Sleeping
Sleeping
| # ----------------------------- | |
| # app.py | |
| # ----------------------------- | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import torch.nn.functional as F | |
| # ----------------------------- | |
| # Load the fine-tuned model | |
| # ----------------------------- | |
# Hugging Face Hub repository holding the fine-tuned classifier
# (replace with your HF repo).
MODEL_REPO = "umarfarzan/deberta-best-clipworthiness"

# Tokenizer and model weights are both loaded from the same repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
# eval() returns the module itself, so loading and switching to evaluation
# mode can be done in a single expression.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO).eval()
| # ----------------------------- | |
| # Prediction function with colored output | |
| # ----------------------------- | |
def predict_viral(transcript) -> str:
    """Classify a transcript as Viral or Not Viral and return an HTML snippet.

    Args:
        transcript: Text to classify. ``None``, empty, or whitespace-only
            input is not sent to the model.

    Returns:
        An HTML string. For real input: the predicted label in a bold,
        colored ``<span>`` (green "Viral" when the predicted class index
        is 1, red "Not Viral" otherwise) followed by the softmax probability
        of the predicted class as a percentage. For blank input: a short
        prompt asking for a transcript.
    """
    # Blank input carries nothing to classify; without this guard the model
    # would still return a confident-looking label for an empty sequence.
    if not transcript or not transcript.strip():
        return "<span style='color:gray;'>Please enter a transcript to analyze.</span>"

    # A single sequence needs no padding. Padding to max_length would force
    # every request through a full 512-token forward pass; truncation still
    # caps long transcripts at 512 tokens.
    inputs = tokenizer(
        transcript,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = F.softmax(logits, dim=-1)
    pred_label = torch.argmax(probs, dim=-1).item()
    pred_prob = probs[0, pred_label].item()

    # Class index 1 is reported as "Viral"; any other index as "Not Viral".
    color, label = ("green", "Viral") if pred_label == 1 else ("red", "Not Viral")
    return (
        f"<span style='color:{color};font-weight:bold;'>{label}</span> "
        f"({pred_prob*100:.2f}% confidence)"
    )
| # ----------------------------- | |
| # Gradio Interface | |
| # ----------------------------- | |
# Build the page: a title and tagline, a transcript box, a button, and an
# HTML area that shows the prediction.
# NOTE(review): the original indentation was lost in this copy; the nesting
# below (only the textbox inside the Row) is a reconstruction - confirm
# against the deployed layout.
with gr.Blocks(title="Acliptic - Revolutionising The Future Of Clipping") as demo:
    # Header section
    gr.Markdown("<h1 style='text-align:center;color:#4B0082;'>Acliptic</h1>")
    gr.Markdown("<h3 style='text-align:center;color:#6A5ACD;'>Revolutionising The Future Of Clipping</h3>")
    gr.Markdown("---")

    # Input area
    with gr.Row():
        transcript_input = gr.Textbox(
            lines=10,
            placeholder="Paste your transcript here...",
            label="Transcript",
        )

    predict_btn = gr.Button("Predict Viral Potential")
    result_output = gr.HTML(label="Prediction")

    # Clicking the button runs predict_viral on the textbox contents and
    # renders the returned HTML in the result area.
    predict_btn.click(
        fn=predict_viral,
        inputs=transcript_input,
        outputs=result_output,
    )

# Start serving the interface.
demo.launch()