File size: 3,123 Bytes
522e0c3
 
507d0d6
 
 
 
 
 
522e0c3
 
507d0d6
 
 
 
522e0c3
507d0d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522e0c3
 
 
507d0d6
 
 
 
 
 
 
 
522e0c3
 
 
507d0d6
 
 
 
 
 
 
 
 
 
 
 
522e0c3
507d0d6
 
 
522e0c3
 
 
507d0d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522e0c3
 
 
507d0d6
 
 
 
522e0c3
507d0d6
 
 
 
 
 
 
 
 
 
 
 
 
522e0c3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    pipeline,
    BitsAndBytesConfig
)
from peft import PeftModel

# ============================================================
# Configuration
# ============================================================

# Hugging Face hub ID of the full-precision base model.
BASE_MODEL = "NousResearch/Llama-2-7b-chat-hf"
# LoRA adapter fine-tuned for CloudLex intent classification.
ADAPTER = "Suramya/Llama-2-7b-CloudLex-Intent-Detection"

NUM_LABELS = 6  # MUST match training (Buying, Support, Careers, Partnership, Explore, Others)

# Human-readable intent names, indexed by label id.
# NOTE(review): assumes this order matches the id2label mapping used
# during adapter training — confirm against the adapter's config.json.
LABEL_NAMES = [
    "Buying",
    "Support",
    "Careers",
    "Partnership",
    "Explore",
    "Others",
]

# ============================================================
# Quantization config (replaces deprecated load_in_4bit)
# ============================================================

# 4-bit quantization so the 7B model fits on a single modest GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # NF4 quantization (the QLoRA default)
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
)

# ============================================================
# Load model + LoRA adapter
# ============================================================

# Load the quantized base model with a sequence-classification head.
# num_labels must equal the label count the adapter was trained with,
# otherwise the classification-head shape won't match the adapter weights.
base_model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=NUM_LABELS,              # 🔑 CRITICAL FIX
    device_map="auto",                  # let accelerate place layers on available devices
    quantization_config=bnb_config,
)

# Attach the fine-tuned LoRA adapter weights on top of the frozen base.
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER,
)

# Tokenizer comes from the adapter repo so any training-time token
# changes are picked up.
tokenizer = AutoTokenizer.from_pretrained(ADAPTER)
# Llama-2 ships with no pad token; reuse EOS so padded batches work.
tokenizer.pad_token = tokenizer.eos_token

# ============================================================
# Pipeline
# ============================================================

# Text-classification pipeline that returns scores for ALL labels,
# as required by the gr.Label output below.
# NOTE(review): `return_all_scores` is deprecated in recent transformers
# releases in favor of `top_k=None`, but the two produce differently
# nested outputs — switching would also require updating predict_intent's
# `clf(message)[0]` indexing. Migrate both together.
clf = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True
)

# ============================================================
# Inference function
# ============================================================

def predict_intent(message: str) -> dict:
    """Classify *message* into CloudLex intents.

    Returns a mapping of human-readable intent name -> probability,
    in the shape gr.Label expects. Empty or whitespace-only input
    returns {} so the UI shows nothing instead of erroring.
    """
    if not message or not message.strip():
        return {}

    # With return_all_scores=True the pipeline yields one list of
    # {"label": ..., "score": ...} dicts per input; take the first (only) input's.
    outputs = clf(message)[0]

    results = {}
    for i, item in enumerate(outputs):
        raw = item.get("label", "")
        # Prefer the label the pipeline actually reports over positional
        # order (the original indexed by position only, which silently
        # mislabels if the pipeline's ordering differs from LABEL_NAMES
        # and raises IndexError if it returns extra entries).
        if raw in LABEL_NAMES:
            name = raw
        elif raw.startswith("LABEL_") and raw[len("LABEL_"):].isdigit() \
                and int(raw[len("LABEL_"):]) < len(LABEL_NAMES):
            # Generic "LABEL_<id>" labels: map id -> human-readable name.
            name = LABEL_NAMES[int(raw[len("LABEL_"):])]
        else:
            # Fall back to the original positional mapping, guarded.
            name = LABEL_NAMES[i] if i < len(LABEL_NAMES) else raw
        results[name] = float(item["score"])

    return results

# ============================================================
# Gradio UI
# ============================================================

# Single-input demo: free-text message in, per-intent probabilities out.
demo = gr.Interface(
    fn=predict_intent,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Type a CloudLex-related message..."
    ),
    outputs=gr.Label(num_top_classes=6),  # show all 6 intents, ranked by score
    title="CloudLex Intent Detection",
    description=(
        "Llama-2-7B fine-tuned with QLoRA for CloudLex intent classification.\n\n"
        "Intents: Buying, Support, Careers, Partnership, Explore, Others"
    ),
    # One example message per intent class, in LABEL_NAMES order.
    examples=[
        ["I'd like to schedule a demo for our law firm"],
        ["My CloudLex account isn't loading properly"],
        ["Are you hiring software engineers?"],
        ["We want to partner with CloudLex"],
        ["What features does CloudLex offer?"],
        ["Just browsing"]
    ],
)

# Blocks and serves the app; module ends here.
demo.launch()