# app.py — Serendib LLM Sinhala text classifiers (Hugging Face Space: Chamaka8)
# Uploaded with huggingface_hub (commit 6154c88, verified)
import os
import torch
import gradio as gr
from transformers import PreTrainedTokenizerFast, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch.nn as nn
import datetime
from huggingface_hub import hf_hub_download
# Hub access token from the Space's environment (None when unset; public repos still work).
HF_TOKEN = os.environ.get("HF_TOKEN")
# One shared causal-LM base checkpoint; each task below layers its own PEFT adapter on top.
BASE_MODEL = "Chamaka8/Serendip-LLM-CPT-SFT-v2"
TOK_MODEL = "Chamaka8/serendib-tokenizer"
NEWS_ADAPTER = "Chamaka8/SerendipLLM-news-classifier"
WRITING_ADAPTER = "Chamaka8/SerendibLLM-v2-writing-head"
SENTIMENT_ADAPTER = "Chamaka8/SerendibLLM-v2-sentiment-head"
# Class-name lists indexed by the classifier heads' argmax output.
# NOTE(review): ordering presumably matches the label order used at training
# time for each head — confirm against the adapter repos.
NEWS_CLASSES = ["Business", "Politics", "Entertainment", "Sports", "Technology"]
WRITING_CLASSES = ["Academic", "Blog", "News", "Creative"]
SENTIMENT_CLASSES = ["Positive", "Negative", "Neutral"]
print(f"===== Startup at {datetime.datetime.now()} =====")
# 4-bit NF4 quantization (double quant, fp16 compute) shared by every base-model load.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
print("Loading tokenizer...")
tokenizer = PreTrainedTokenizerFast.from_pretrained(TOK_MODEL, token=HF_TOKEN)
tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): pad_token is set to EOS but pad_token_id is then forced to 0;
# if EOS's id is not 0 these disagree. Presumably id 0 is what the heads were
# trained with — confirm before changing.
tokenizer.pad_token_id = 0
# Right padding: run_inference() relies on this when picking the last real token.
tokenizer.padding_side = "right"
print("Tokenizer ready")
def load_model_and_head(adapter_repo, num_classes, hidden_size=4096):
    """Load the 4-bit quantized base model with a PEFT adapter and its classifier head.

    Args:
        adapter_repo: Hub repo id holding the PEFT adapter weights and a
            ``classifier_head.pt`` file (a plain ``nn.Linear`` state dict).
        num_classes: Output dimension of the classification head.
        hidden_size: Hidden dimension of the base model's final layer
            (default 4096, matching the base model used by this app).

    Returns:
        Tuple ``(model, head)``: the eval-mode PeftModel and eval-mode
        ``nn.Linear(hidden_size, num_classes)`` head.
    """
    print(f"Loading {adapter_repo}...")
    # NOTE(review): bitsandbytes 4-bit with device_map="cpu" is not supported by
    # all bitsandbytes versions — confirm this works in the deployment image.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb,
        device_map="cpu",
        token=HF_TOKEN
    )
    model = PeftModel.from_pretrained(base, adapter_repo, token=HF_TOKEN)
    model.eval()
    head_path = hf_hub_download(
        repo_id=adapter_repo,
        filename="classifier_head.pt",
        token=HF_TOKEN
    )
    head = nn.Linear(hidden_size, num_classes)
    # weights_only=True: the file is a tensor-only state dict, and this blocks
    # arbitrary-code execution via pickle from a downloaded artifact.
    head.load_state_dict(torch.load(head_path, map_location="cpu", weights_only=True))
    head.eval()
    print(f"{adapter_repo} ready")
    return model, head
# Eagerly load one (PeftModel, Linear head) pair per task at process startup;
# class counts must equal len() of the matching *_CLASSES list above.
news_model, news_head = load_model_and_head(NEWS_ADAPTER, 5)
writing_model, writing_head = load_model_and_head(WRITING_ADAPTER, 4)
sentiment_model, sentiment_head = load_model_and_head(SENTIMENT_ADAPTER, 3)
print("All models ready!")
def run_inference(text, model, head):
    """Return the predicted class index for *text*.

    Feeds the text through the adapter-augmented causal LM, takes the last
    layer's hidden state at the final real (non-pad) token, and applies the
    linear classifier head.

    Args:
        text: Input string to classify.
        model: PeftModel producing hidden states.
        head: ``nn.Linear`` classifier over the last hidden state.

    Returns:
        int: argmax class index from the head's logits.
    """
    # No padding="max_length" here: with a causal LM, right-side pad tokens
    # cannot influence earlier positions, and we index the last *real* token
    # below — so dropping the padding yields the same result with less compute.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=256,
        truncation=True
    )
    with torch.no_grad():
        hidden = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_hidden_states=True
        ).hidden_states[-1]
    # Index of the last non-pad token (batch size is always 1 here).
    lengths = inputs["attention_mask"].sum(dim=1) - 1
    # Head was trained in float32, so upcast the (possibly fp16) hidden state.
    last_hidden = hidden[0, lengths[0], :].unsqueeze(0).float()
    logits = head(last_hidden)
    return torch.argmax(logits, dim=1).item()
def classify_news(text):
    """Map *text* to its predicted news-category label."""
    idx = run_inference(text, news_model, news_head)
    return NEWS_CLASSES[idx]
def classify_writing(text):
    """Map *text* to its predicted writing-style label."""
    idx = run_inference(text, writing_model, writing_head)
    return WRITING_CLASSES[idx]
def classify_sentiment(text):
    """Map *text* to its predicted sentiment label."""
    idx = run_inference(text, sentiment_model, sentiment_head)
    return SENTIMENT_CLASSES[idx]
def _add_classifier_tab(tab_title, description, input_label, output_label, handler):
    """Build one classifier tab (textbox -> button -> label) wired to *handler*."""
    with gr.Tab(tab_title):
        gr.Markdown(description)
        text_in = gr.Textbox(label=input_label, lines=5)
        run_btn = gr.Button("Classify", variant="primary")
        result = gr.Label(label=output_label)
        run_btn.click(fn=handler, inputs=text_in, outputs=result)

# Assemble the three-tab demo UI; each tab shares the same layout.
with gr.Blocks(title="Serendib LLM Classifiers") as demo:
    gr.Markdown("## Serendib LLM — Sinhala Text Classifiers")
    _add_classifier_tab(
        "News Category",
        "Classifies Sinhala news into: Business · Politics · Entertainment · Sports · Technology",
        "Sinhala News Text",
        "News Category",
        classify_news,
    )
    _add_classifier_tab(
        "Writing Style",
        "Classifies Sinhala text into: Academic · Blog · News · Creative",
        "Sinhala Text",
        "Writing Style",
        classify_writing,
    )
    _add_classifier_tab(
        "Sentiment",
        "Classifies Sinhala text into: Positive · Negative · Neutral",
        "Sinhala Text",
        "Sentiment",
        classify_sentiment,
    )
# Bind on all interfaces at the standard HF Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)