File size: 3,729 Bytes
1923dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import threading

import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp import Llama


# Model selection. Each value may be overridden through the environment
# variable of the same name.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "EREN121232/MAJESTIC-FIN-R1-gguf")
MODEL_FILENAME = os.environ.get("MODEL_FILENAME", "MAJESTIC-FIN-R1-Q8_0.gguf")
MODEL_LABEL = os.environ.get("MODEL_LABEL", "MAJESTIC-FIN-R1 Q8_0")

# Value passed to Llama as ``n_ctx``.
N_CTX = int(os.environ.get("N_CTX", "4096"))

# Value passed to Llama as ``n_threads``. CPU_CORES takes precedence over
# N_THREADS; when neither is set, fall back to the detected CPU count, or 2
# if os.cpu_count() cannot determine it.
N_THREADS = int(
    os.environ.get(
        "CPU_CORES",
        os.environ.get("N_THREADS", str(os.cpu_count() or 2)),
    )
)

# Lazily created model instance, the lock guarding its creation, and the lock
# that serialises inference calls.
_MODEL = None
_MODEL_LOCK = threading.Lock()
_INFER_LOCK = threading.Lock()


def get_model() -> Llama:
    """Return the shared ``Llama`` instance, creating it on first use.

    On the first call the GGUF file ``MODEL_FILENAME`` is fetched from
    ``MODEL_REPO_ID`` with ``hf_hub_download`` and loaded using ``N_CTX``,
    ``N_THREADS`` and ``n_gpu_layers=0``. The check and the load both happen
    while holding ``_MODEL_LOCK``, so callers arriving during the load wait
    for it instead of starting a second one.
    """
    global _MODEL
    with _MODEL_LOCK:
        # Already loaded: hand back the cached instance.
        if _MODEL is not None:
            return _MODEL

        gguf_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
        _MODEL = Llama(
            model_path=gguf_path,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_gpu_layers=0,
            verbose=False,
        )
        return _MODEL


def generate(prompt: str, system_prompt: str, temperature: float, max_tokens: int, top_p: float, repeat_penalty: float) -> str:
    """Run one chat completion and return the assistant's reply text.

    Args:
        prompt: User message. ``None`` or whitespace-only input is treated as
            empty and short-circuits with a hint instead of calling the model.
        system_prompt: Optional system message; skipped when ``None`` or blank.
        temperature: Sampling temperature, coerced to ``float``.
        max_tokens: Maximum number of tokens to generate, coerced to ``int``.
        top_p: Nucleus sampling threshold, coerced to ``float``.
        repeat_penalty: Repetition penalty, coerced to ``float``.

    Returns:
        The stripped reply text, ``""`` if the model returned no content, or
        ``"Please enter a prompt."`` when the prompt is empty.
    """
    # API clients may send null for a text field; treat it like an empty
    # string rather than failing on ``None.strip()``.
    prompt = (prompt or "").strip()
    system_prompt = (system_prompt or "").strip()

    if not prompt:
        return "Please enter a prompt."

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    llm = get_model()
    # Only one completion runs at a time on the shared model instance.
    with _INFER_LOCK:
        response = llm.create_chat_completion(
            messages=messages,
            temperature=float(temperature),
            max_tokens=int(max_tokens),
            top_p=float(top_p),
            repeat_penalty=float(repeat_penalty),
        )

    # The message content is optional in the response schema; normalise a
    # missing value to an empty string before stripping.
    content = response["choices"][0]["message"]["content"]
    return (content or "").strip()


# UI definition. Components are created in display order; both the button and
# the prompt's submit event call generate() with the same inputs.
with gr.Blocks(title="MAJESTIC FIN R1 Free API") as demo:
    gr.Markdown(
        f"""
        # MAJESTIC FIN R1 Free API

        Public CPU deployment for `{MODEL_LABEL}` backed by `llama-cpp-python`.
        The API endpoint name is `/chat`.
        """
    )

    # Main input / output boxes.
    prompt = gr.Textbox(
        label="Prompt",
        lines=8,
        placeholder="Ask about finance, markets, accounting, or your fine-tuned task.",
    )
    output = gr.Textbox(label="Response", lines=14)

    # Sampling controls, collapsed by default.
    with gr.Accordion("Generation Settings", open=False):
        system_prompt = gr.Textbox(
            label="System Prompt",
            lines=4,
            value="You are MAJESTIC-FIN-R1, a helpful finance-focused assistant.",
        )
        temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature")
        max_tokens = gr.Slider(minimum=64, maximum=1024, value=256, step=32, label="Max Tokens")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P")
        repeat_penalty = gr.Slider(minimum=1.0, maximum=1.5, value=1.1, step=0.05, label="Repeat Penalty")

    run_button = gr.Button("Generate", variant="primary")

    # Clickable sample prompts that fill the prompt box.
    gr.Examples(
        examples=[
            ["Summarize the key risks in a company's balance sheet."],
            ["Explain EBITDA vs free cash flow in simple terms."],
            ["Give a short market outlook for a cautious investor."],
        ],
        inputs=prompt,
    )

    # Arguments shared by both event listeners; order matches generate()'s
    # parameter order.
    _generation_inputs = (
        prompt,
        system_prompt,
        temperature,
        max_tokens,
        top_p,
        repeat_penalty,
    )
    _listener_options = {
        "fn": generate,
        "outputs": output,
        "show_progress": "minimal",
        "concurrency_limit": 1,
    }

    # The button click is the listener given the explicit API name "chat".
    run_button.click(
        inputs=list(_generation_inputs),
        api_name="chat",
        **_listener_options,
    )

    # Submitting from the prompt box runs the same generation.
    prompt.submit(
        inputs=list(_generation_inputs),
        **_listener_options,
    )


if __name__ == "__main__":
    # Enable the request queue (at most 16 waiting requests), then start the
    # server.
    _queued_demo = demo.queue(max_size=16)
    _queued_demo.launch()