"""Gradio demo for CloudLex intent detection.

Loads a 4-bit quantized Llama-2-7B sequence-classification model, attaches the
CloudLex LoRA adapter, and serves a small web UI that reports a score for each
intent given a user-supplied message.
"""

from typing import Dict

import gradio as gr
import torch
from peft import PeftModel
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

# ============================================================
# Configuration
# ============================================================
BASE_MODEL = "NousResearch/Llama-2-7b-chat-hf"
ADAPTER = "Suramya/Llama-2-7b-CloudLex-Intent-Detection"

# Order matters: the position of each name is its class id and MUST match the
# label order used during training.
LABEL_NAMES = [
    "Buying",
    "Support",
    "Careers",
    "Partnership",
    "Explore",
    "Others",
]

# Derived from LABEL_NAMES so the head size can never drift from the label list.
NUM_LABELS = len(LABEL_NAMES)

# Class-id <-> name mappings registered on the model config so the pipeline
# reports human-readable labels instead of the generic "LABEL_<i>".
ID2LABEL = dict(enumerate(LABEL_NAMES))
LABEL2ID = {name: idx for idx, name in ID2LABEL.items()}

# ============================================================
# Quantization config (passed via `quantization_config` instead of the
# deprecated `load_in_4bit` keyword)
# ============================================================
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# ============================================================
# Load model + LoRA adapter
# ============================================================
base_model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    # The classification head must be sized for the training label set.
    num_labels=NUM_LABELS,
    id2label=ID2LABEL,
    label2id=LABEL2ID,
    device_map="auto",
    quantization_config=bnb_config,
)

model = PeftModel.from_pretrained(
    base_model,
    ADAPTER,
)

tokenizer = AutoTokenizer.from_pretrained(ADAPTER)
tokenizer.pad_token = tokenizer.eos_token

# ============================================================
# Pipeline
# ============================================================
clf = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    # `top_k=None` returns the score of every class; it supersedes the
    # deprecated `return_all_scores=True`.
    top_k=None,
)


# ============================================================
# Inference function
# ============================================================
def predict_intent(message: str) -> Dict[str, float]:
    """Return the score of every intent for ``message``.

    Args:
        message: Free-form user text to classify.

    Returns:
        A mapping from intent name to its score. An empty mapping is returned
        when ``message`` is empty or contains only whitespace.
    """
    if not message or not message.strip():
        return {}

    # Passing a one-element list makes the result one entry per input, so
    # ``[0]`` is always the list of per-class scores for ``message``.
    scores = clf([message])[0]

    # Key by the label the pipeline reports rather than by list position:
    # with ``top_k=None`` the entries are ordered by score, not by class id.
    return {item["label"]: float(item["score"]) for item in scores}


# ============================================================
# Gradio UI
# ============================================================
demo = gr.Interface(
    fn=predict_intent,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Type a CloudLex-related message...",
    ),
    outputs=gr.Label(num_top_classes=NUM_LABELS),
    title="CloudLex Intent Detection",
    description=(
        "Llama-2-7B fine-tuned with QLoRA for CloudLex intent classification.\n\n"
        "Intents: Buying, Support, Careers, Partnership, Explore, Others"
    ),
    examples=[
        ["I'd like to schedule a demo for our law firm"],
        ["My CloudLex account isn't loading properly"],
        ["Are you hiring software engineers?"],
        ["We want to partner with CloudLex"],
        ["What features does CloudLex offer?"],
        ["Just browsing"],
    ],
)

# Only start the web server when run as a script, so importing this module
# (e.g. to reuse `predict_intent`) does not launch the UI.
if __name__ == "__main__":
    demo.launch()