# Hugging Face Spaces app.
# Build status captured from the Spaces page at scrape time: "Runtime error"
# (see load_model_and_head — 4-bit quantization on CPU is the likely cause).
| import os | |
| import torch | |
| import gradio as gr | |
| from transformers import PreTrainedTokenizerFast, AutoModelForCausalLM, BitsAndBytesConfig | |
| from peft import PeftModel | |
| import torch.nn as nn | |
| import datetime | |
| from huggingface_hub import hf_hub_download | |
# Hub access token; None when the HF_TOKEN secret is not configured.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Repositories on the Hugging Face Hub — presumably private, since every
# download below passes token=HF_TOKEN (TODO confirm).
BASE_MODEL = "Chamaka8/Serendip-LLM-CPT-SFT-v2"
TOK_MODEL = "Chamaka8/serendib-tokenizer"
NEWS_ADAPTER = "Chamaka8/SerendipLLM-news-classifier"
WRITING_ADAPTER = "Chamaka8/SerendibLLM-v2-writing-head"
SENTIMENT_ADAPTER = "Chamaka8/SerendibLLM-v2-sentiment-head"

# Label lists: the classifier heads emit logits whose argmax indexes into
# these lists, so the order here must match the order used at training time.
NEWS_CLASSES = ["Business", "Politics", "Entertainment", "Sports", "Technology"]
WRITING_CLASSES = ["Academic", "Blog", "News", "Creative"]
SENTIMENT_CLASSES = ["Positive", "Negative", "Neutral"]

# Startup marker so Space restarts are visible in the container logs.
print(f"===== Startup at {datetime.datetime.now()} =====")
# 4-bit NF4 quantization config with fp16 compute and double quantization.
# NOTE(review): bitsandbytes 4-bit loading requires a CUDA device; combined
# with device_map="cpu" below this is expected to fail at model load time.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

print("Loading tokenizer...")
tokenizer = PreTrainedTokenizerFast.from_pretrained(TOK_MODEL, token=HF_TOKEN)
tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): pad_token is set to the EOS string, but pad_token_id is then
# forced to 0 — these disagree unless id 0 happens to be EOS/PAD in this
# vocabulary. Verify against the tokenizer's vocab before relying on it.
tokenizer.pad_token_id = 0
# Right padding matters: run_inference locates the last real token via
# attention_mask.sum() - 1, which only works with right-side padding.
tokenizer.padding_side = "right"
print("Tokenizer ready")
def load_model_and_head(adapter_repo, num_classes, hidden_size=4096):
    """Load the base LM with a PEFT adapter plus its linear classifier head.

    Args:
        adapter_repo: Hub repo id holding the LoRA adapter and a
            ``classifier_head.pt`` state dict.
        num_classes: Output dimension of the classifier head.
        hidden_size: Hidden width of the base model's last layer (input
            dimension of the head). Defaults to 4096, matching the
            previously hard-coded value.

    Returns:
        ``(model, head)`` — the adapter-wrapped model in eval mode and the
        CPU-resident ``nn.Linear`` head in eval mode.
    """
    print(f"Loading {adapter_repo}...")
    use_cuda = torch.cuda.is_available()
    # BUG FIX: bitsandbytes 4-bit quantization only works on CUDA devices.
    # The original passed quantization_config together with device_map="cpu",
    # which fails at load time (the Space's "Runtime error"). Quantize only
    # when a GPU is present; otherwise load unquantized on CPU.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb if use_cuda else None,
        device_map="auto" if use_cuda else "cpu",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        token=HF_TOKEN,
    )
    model = PeftModel.from_pretrained(base, adapter_repo, token=HF_TOKEN)
    model.eval()
    # The classifier head is stored as a plain state dict alongside the adapter.
    head_path = hf_hub_download(
        repo_id=adapter_repo,
        filename="classifier_head.pt",
        token=HF_TOKEN,
    )
    head = nn.Linear(hidden_size, num_classes)
    head.load_state_dict(torch.load(head_path, map_location="cpu"))
    head.eval()
    print(f"{adapter_repo} ready")
    return model, head
# Eagerly load all three task models at startup. Each call downloads the
# (quantized) base model again, so startup is slow and memory-heavy;
# the class counts must match the label lists defined above.
news_model, news_head = load_model_and_head(NEWS_ADAPTER, 5)
writing_model, writing_head = load_model_and_head(WRITING_ADAPTER, 4)
sentiment_model, sentiment_head = load_model_and_head(SENTIMENT_ADAPTER, 3)
print("All models ready!")
def run_inference(text, model, head):
    """Classify ``text`` and return the predicted class index.

    Tokenizes to a fixed length of 256, runs the LM once, takes the last
    layer's hidden state at the final non-pad token, and feeds it to the
    linear ``head``.

    Args:
        text: Input string to classify.
        model: Causal LM (possibly PEFT-wrapped) that supports
            ``output_hidden_states=True``.
        head: ``nn.Linear`` classifier head, resident on CPU.

    Returns:
        int: argmax index over the head's logits.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=256,
        truncation=True,
        padding="max_length",
    )
    # BUG FIX: the original never moved inputs to the model's device, which
    # crashes whenever the model is on GPU. A no-op when the model is on CPU.
    device = next(model.parameters()).device
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden = outputs.hidden_states[-1]
        # Index of the last real token; valid because padding_side == "right".
        lengths = attention_mask.sum(dim=1) - 1
        # Head lives on CPU in fp32, so bring the pooled vector back there.
        last_hidden = hidden[0, lengths[0], :].unsqueeze(0).float().cpu()
        logits = head(last_hidden)
        pred = torch.argmax(logits, dim=1).item()
    return pred
def classify_news(text):
    """Map Sinhala news text to one of the NEWS_CLASSES labels."""
    class_index = run_inference(text, news_model, news_head)
    return NEWS_CLASSES[class_index]
def classify_writing(text):
    """Map Sinhala text to one of the WRITING_CLASSES style labels."""
    class_index = run_inference(text, writing_model, writing_head)
    return WRITING_CLASSES[class_index]
def classify_sentiment(text):
    """Map Sinhala text to one of the SENTIMENT_CLASSES labels."""
    class_index = run_inference(text, sentiment_model, sentiment_head)
    return SENTIMENT_CLASSES[class_index]
# Gradio UI: one tab per classifier, each with a textbox, a button, and a
# Label output wired to the matching classify_* function. Component creation
# order inside the `with` blocks determines layout.
with gr.Blocks(title="Serendib LLM Classifiers") as demo:
    gr.Markdown("## Serendib LLM — Sinhala Text Classifiers")
    with gr.Tab("News Category"):
        gr.Markdown("Classifies Sinhala news into: Business · Politics · Entertainment · Sports · Technology")
        news_input = gr.Textbox(label="Sinhala News Text", lines=5)
        news_btn = gr.Button("Classify", variant="primary")
        news_output = gr.Label(label="News Category")
        news_btn.click(fn=classify_news, inputs=news_input, outputs=news_output)
    with gr.Tab("Writing Style"):
        gr.Markdown("Classifies Sinhala text into: Academic · Blog · News · Creative")
        writing_input = gr.Textbox(label="Sinhala Text", lines=5)
        writing_btn = gr.Button("Classify", variant="primary")
        writing_output = gr.Label(label="Writing Style")
        writing_btn.click(fn=classify_writing, inputs=writing_input, outputs=writing_output)
    with gr.Tab("Sentiment"):
        gr.Markdown("Classifies Sinhala text into: Positive · Negative · Neutral")
        sentiment_input = gr.Textbox(label="Sinhala Text", lines=5)
        sentiment_btn = gr.Button("Classify", variant="primary")
        sentiment_output = gr.Label(label="Sentiment")
        sentiment_btn.click(fn=classify_sentiment, inputs=sentiment_input, outputs=sentiment_output)

# 0.0.0.0:7860 is the standard bind address/port for a Spaces container.
demo.launch(server_name="0.0.0.0", server_port=7860)