import os

import torch
import torch.nn as nn
import streamlit as st
from pydantic import BaseModel
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel

# Get the token from environment variable (optional)
hf_token = os.environ.get("HF_TOKEN")

# Define model IDs
adapter_model_id = "seniormgt/arabicmgt-test"
base_model_id = "Alibaba-NLP/gte-multilingual-base"
# Define the model: GTE encoder with a pooling layer and a single-logit binary head
class GTEClassifier(nn.Module):
    def __init__(self, model_name=base_model_id):
        super(GTEClassifier, self).__init__()
        self.base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.config = self.base_model.config
        self.pooler = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.pooler_activation = nn.Tanh()
        self.dropout = nn.Dropout(0.0)
        self.classifier = nn.Linear(self.config.hidden_size, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, labels=None, **kwargs):
        if inputs_embeds is not None:
            outputs = self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        else:
            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        # Pool the [CLS] token, project it, and score with the single-logit head
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.pooler(pooled_output)
        pooled_output = self.pooler_activation(pooled_output)
        logits = self.classifier(self.dropout(pooled_output)).squeeze(-1)
        loss = None
        if labels is not None:
            labels = labels.float()
            loss = self.loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}
# Load tokenizer and wrap the classifier with the trained PEFT adapter
tokenizer = AutoTokenizer.from_pretrained(adapter_model_id, token=hf_token, trust_remote_code=True)
base_model = GTEClassifier()
peft_model = PeftModel.from_pretrained(base_model, adapter_model_id, token=hf_token)
peft_model.eval()  # inference mode
# Define prediction: tokenize a single text and return the predicted label with its confidence
def classify_text(text):
    inputs = tokenizer(text, max_length=512, padding=True, return_attention_mask=True, return_tensors="pt", truncation=True)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    with torch.no_grad():
        outputs = peft_model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs["logits"]
    probs = torch.sigmoid(logits).cpu().numpy().squeeze()
    pred_label = int(probs >= 0.5)
    return {"label": str(pred_label), "confidence": float(probs)}
# 🔹 Streamlit UI
st.title("Text Classification (MGT Detection)")
text = st.text_area("Enter text", height=150)
if st.button("Classify") and text.strip():
    result = classify_text(text)
    st.json(result)
# 🔹 FastAPI endpoint
app = FastAPI()

class Input(BaseModel):
    data: list

# Expose prediction over HTTP; the "/predict" path is an assumed route name
@app.post("/predict")
async def predict(request: Request):
    payload = await request.json()
    text = payload["data"][0]["text"]
    result = classify_text(text)
    return {"data": [result]}