import gradio as gr
import torch
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
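
# Runtime dependencies (sketch): gradio, torch, transformers, peft, and accelerate
# (device_map="auto" below relies on accelerate for weight placement)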

# 🔹 Hub repo to load; this can be a full model or a PEFT adapter (see the fallback below)
checkpoint = "tarun7r/Finance-Llama-8B"

# Try to load tokenizer
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
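
# Llama-family tokenizers often ship without a pad token; falling back to EOS is a
# common safeguard (only matters if padded or batched tokenization is ever used)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token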

# 🔹 Try loading model directly
try:
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
except Exception as e:
    # If it's actually a PEFT adapter, load base model first
    print("Direct load failed, trying as PEFT adapter...", e)
    base_model = AutoModelForCausalLM.from_pretrained(
        # NOTE: replace with the adapter's actual base model; Llama 2 has no 8B
        # variant, so "meta-llama/Llama-2-8b-hf" will not resolve on the Hub if
        # this fallback is ever reached. Gated Llama repos also need an accepted
        # license and an HF token.
        "meta-llama/Llama-2-8b-hf",
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    model = PeftModel.from_pretrained(base_model, checkpoint)


def respond(
    message: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Chatbot response function with fallback if chat_template is missing.
    """
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})

    # 🔹 Try using chat template if available
    try:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # fallback manual formatting
        text = f"{system_message}\n"
        for turn in history:
            text += f"{turn['role'].capitalize()}: {turn['content']}\n"
        text += f"User: {message}\nAssistant:"

    # Tokenize inputs
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    response = ""
    with torch.no_grad():
        generated = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

    # Decode only new tokens
    new_tokens = generated[0][inputs["input_ids"].shape[-1]:]
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Pseudo-stream the reply word by word (generation itself is not streamed;
    # splitting on whitespace also collapses newlines in the model output)
    for token in decoded.split():
        response += token + " "
        yield response.strip()
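

# Optional sketch (not wired into the UI below): true token streaming with
# transformers' TextIteratorStreamer. generate() runs in a background thread and
# partial text is yielded as tokens arrive. Assumes the checkpoint ships a chat
# template; swap `respond` for this function in gr.ChatInterface to try it.
from threading import Thread

from transformers import TextIteratorStreamer


def respond_streaming(
    message: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    # skip_prompt drops the echoed input; decode kwargs are forwarded to decode()
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            streamer=streamer,
        ),
    ).start()

    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial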


# 🔹 Gradio ChatInterface
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(
            value=(
                "You are Marduk (ماردوك), a Financial Assistant made in Iraq for FinTech Hackathon by Makers. "
                "You explain markets clearly and give simple professional career advice "
                "to both people and businesses. You speak English and Arabic Only."
            ),
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=1024, value=500, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.9, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
    ],
)

with gr.Blocks() as demo:
    chatbot.render()

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)