File size: 1,976 Bytes
3ef1aec
a8ba867
3304f90
a8ba867
3304f90
3ef1aec
 
0567ace
a8ba867
0567ace
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8ba867
 
3ef1aec
0567ace
 
 
 
 
3ef1aec
1e9f3ed
 
 
 
 
3304f90
a8ba867
4a0c31e
a8ba867
3304f90
 
 
 
 
 
4a0c31e
3304f90
 
 
 
0567ace
3304f90
3ef1aec
4a0c31e
 
3304f90
 
2eb97c5
3304f90
3ef1aec
3304f90
a8ba867
 
 
 
4a0c31e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import gradio as gr
import torch

# Hugging Face Hub repo id of the LoRA chat model served by this app.
MODEL_ID = "SatyamSinghal/taskmind-1.1b-chat-lora"
# Hub access token read from the environment; None when unset (public access).
HF_TOKEN = os.getenv("HF_TOKEN")

# Lazily-initialized text-generation pipeline; populated by load_model().
pipe = None

def load_model():
    """Build the global text-generation pipeline once; no-op on later calls."""
    global pipe
    if pipe is not None:
        return

    # Heavy imports are deferred so the app starts fast and import errors
    # surface as a load failure instead of crashing at module import time.
    from peft import AutoPeftModelForCausalLM
    from transformers import AutoTokenizer, pipeline

    # Half precision only when a CUDA device is available.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

    print("Loading model...")
    model = AutoPeftModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )

    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    print("Model loaded successfully.")


def respond(message, history):
    """Chat callback for gr.ChatInterface.

    Parameters
    ----------
    message : str
        The user's latest message.
    history : list
        Prior turns, either as (user_msg, assistant_msg) pairs (Gradio's
        legacy tuple format) or as {"role", "content"} dicts (Gradio's
        ``type="messages"`` format). Both are accepted.

    Returns
    -------
    str
        The assistant's reply, or an error string if the model fails to load.
    """
    try:
        load_model()
    except Exception as e:
        # Surface load failures in the chat UI instead of crashing the request.
        return f"❌ Model failed to load: {str(e)}"

    messages = []
    for turn in history:
        if isinstance(turn, dict):
            # Gradio "messages" format: already a role/content dict.
            if turn.get("content"):
                messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            # Gradio "tuples" format: a (user, assistant) pair.
            user_msg, assistant_msg = turn
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    result = pipe(
        messages,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    # Chat pipelines return the whole conversation as a list of role/content
    # dicts; the final entry is the newly generated assistant turn.
    generated = result[0]["generated_text"]
    if isinstance(generated, list):
        return generated[-1]["content"]
    return str(generated)


# Wire the chat UI: respond() is invoked per message with the running history.
demo = gr.ChatInterface(
    fn=respond,
    title="TaskMind Interface",
    description="Chat with the TaskMind LoRA model.",
    # Clickable sample prompts shown beneath the chat box.
    examples=[
        "Who are you?",
        "@Satyam fix the growstreams deck ASAP NO Delay",
        "done bhai, merged the PR",
        "login page 60% ho gaya",
        "getting 500 error on registration",
    ],
)

# Launch the Gradio server only when run as a script (not when imported,
# e.g. by a Hugging Face Space runner that calls demo itself).
if __name__ == "__main__":
    demo.launch()