Spaces:
Sleeping
Sleeping
| import torch | |
| import json | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, AutoConfig | |
| from peft import LoraConfig, get_peft_model, PeftModel | |
| from araberta_setting.modeling_bilstm_crf import BERT_BiLSTM_CRF | |
| from seq2seq_inference import infer_t5_prompt | |
| from huggingface_hub import hf_hub_download | |
# Registry of selectable models. Each UI-facing key maps to the Hugging Face
# repo ids of its base checkpoint and its fine-tuned (LoRA) adapter.
MODEL_OPTIONS = {
    "Araberta": {
        "base": "asmashayea/absa-araberta",
        "adapter": "asmashayea/absa-araberta",
    },
    "mT5": {
        "base": "google/mt5-base",
        "adapter": "asmashayea/mt4-absa",
    },
    "mBART": {
        "base": "facebook/mbart-large-50-many-to-many-mmt",
        "adapter": "asmashayea/mbart-absa",
    },
    # The two GPT entries point at a placeholder base checkpoint.
    "GPT3.5": {
        "base": "bigscience/bloom-560m",  # placeholder
        "adapter": "asmashayea/gpt-absa",
    },
    "GPT4o": {
        "base": "bigscience/bloom-560m",  # placeholder
        "adapter": "asmashayea/gpt-absa",
    },
}

# In-process cache: model key -> (tokenizer, model), filled lazily on first use.
cached_models = {}
def load_araberta():
    """Load the AraBERT + BiLSTM-CRF ABSA model and its tokenizer.

    Fetches the tokenizer, the LoRA-adapted encoder and the serialized
    BiLSTM-CRF head from the Hub, assembles them on the available device,
    caches the pair under "Araberta" in `cached_models`, and returns it.

    Returns:
        tuple: (tokenizer, model) with the model in eval mode.
    """
    # NOTE(review): this repo id ("absa-arabert") differs from
    # MODEL_OPTIONS["Araberta"] ("absa-araberta") — confirm which is correct.
    path = "asmashayea/absa-arabert"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(path)

    # Wrap the base encoder with its LoRA adapter.
    encoder = AutoModel.from_pretrained(path)
    peft_encoder = get_peft_model(encoder, LoraConfig.from_pretrained(path))

    # Rebuild the BiLSTM-CRF head and restore its weights, always mapping the
    # checkpoint onto the device we are actually running on.
    local_pt = hf_hub_download(repo_id=path, filename="bilstm_crf_head.pt")
    model = BERT_BiLSTM_CRF(peft_encoder, AutoConfig.from_pretrained(path))
    state_dict = torch.load(local_pt, map_location=torch.device(device))
    model.load_state_dict(state_dict)
    model.to(device).eval()

    cached_models["Araberta"] = (tokenizer, model)
    return tokenizer, model
def _merge_wordpieces(tokens):
    """Join WordPiece tokens back into a surface string.

    Continuation pieces (prefixed "##") are glued onto the previous token
    instead of being space-separated, so ["res", "##taurant"] becomes
    "restaurant".  (The previous `" ".join(...).replace("##", "")` left a
    stray space: "res taurant".)
    """
    return " ".join(tokens).replace(" ##", "").replace("##", "")


def _group_aspects(tokens, labels):
    """Collapse BIO-tagged tokens into [{"aspect": ..., "sentiment": ...}] spans.

    A "B-<sent>" label starts a span; "I-<sent>" extends it only when the
    sentiment matches the open span; anything else closes the open span.
    """
    aspects = []
    current_tokens, current_sentiment = [], None

    def flush():
        # Emit the currently open span, if any.
        if current_tokens:
            aspects.append({
                "aspect": _merge_wordpieces(current_tokens),
                "sentiment": current_sentiment,
            })

    for token, label in zip(tokens, labels):
        if label.startswith("B-"):
            flush()
            current_tokens = [token]
            current_sentiment = label.split("-")[1]
        elif label.startswith("I-") and current_sentiment == label.split("-")[1]:
            current_tokens.append(token)
        else:
            flush()
            current_tokens, current_sentiment = [], None
    flush()
    return aspects


def infer_araberta(text):
    """Run token-level ABSA with the AraBERT + BiLSTM-CRF model.

    Args:
        text: input sentence (str).

    Returns:
        dict with:
          - "token_predictions": "token: label | token: label | ..." string,
          - "aspects": list of {"aspect", "sentiment"} span dicts.
    """
    if "Araberta" not in cached_models:
        tokenizer, model = load_araberta()
    else:
        tokenizer, model = cached_models["Araberta"]
    device = next(model.parameters()).device

    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        padding='max_length',
        max_length=128
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # 'logits' holds predicted label ids here — presumably the CRF's decoded
    # path rather than raw scores; verify against BERT_BiLSTM_CRF.forward.
    predicted_ids = outputs['logits'][0].cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
    predicted_labels = [model.config.id2label.get(p, 'O') for p in predicted_ids]

    # Drop special tokens ([CLS]/[SEP]/[PAD]) while keeping token/label alignment.
    specials = set(tokenizer.all_special_tokens)
    clean_tokens = [t for t in tokens if t not in specials]
    clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in specials]

    # Horizontal token:label pairs for display.
    horizontal_output = " | ".join(
        f"{token}: {label}" for token, label in zip(clean_tokens, clean_labels)
    )

    return {
        "token_predictions": horizontal_output,
        "aspects": _group_aspects(clean_tokens, clean_labels),
    }
def load_model(model_key):
    """Return (tokenizer, model) for a seq2seq ABSA model, loading lazily.

    Looks up the base checkpoint and LoRA adapter ids for `model_key` in
    MODEL_OPTIONS, loads them onto the available device, puts the model in
    eval mode and memoizes the pair in `cached_models`.
    """
    cached = cached_models.get(model_key)
    if cached is not None:
        return cached

    option = MODEL_OPTIONS[model_key]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Tokenizer is loaded from the adapter repo — presumably it carries any
    # tokens added during fine-tuning; confirm against the adapter repo.
    tokenizer = AutoTokenizer.from_pretrained(option["adapter"])
    backbone = AutoModelForSeq2SeqLM.from_pretrained(option["base"]).to(device)
    model = PeftModel.from_pretrained(backbone, option["adapter"]).to(device)
    model.eval()

    cached_models[model_key] = (tokenizer, model)
    return tokenizer, model
def predict_absa(text, model_choice):
    """Dispatch `text` to the ABSA pipeline selected by `model_choice`.

    Supported choices: 'mT5' / 'mBART' (seq2seq prompt inference) and
    'Araberta' (token-level BIO tagging). Any other choice yields an
    error dict instead of raising.
    """
    if model_choice == 'Araberta':
        return infer_araberta(text)
    if model_choice in ('mT5', 'mBART'):
        tokenizer, model = load_model(model_choice)
        return infer_t5_prompt(text, tokenizer, model)
    return {"error": f"Model {model_choice} not supported"}