File size: 3,397 Bytes
d9a7e49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import json
import os
from threading import Thread

import gradio as gr
import requests
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer



# --- Configuration ---

# Hub repo that holds both the base model and the PEFT adapter checkpoints.
BASE_MODEL_PATH = "algorythmtechnologies/zenith_coder_v1.1"

# Subfolder inside the Hub repo containing the PEFT/LoRA adapter weights.
ADAPTER_SUBFOLDER = "checkpoint-300"

# SECURITY: this key was previously hard-coded in source. Read it from the
# environment; the literal fallback keeps existing deployments working, but
# the exposed key should be rotated and the fallback removed.
SERPER_API_KEY = os.environ.get(
    "SERPER_API_KEY", "e43f937b155ec4feafb0458e4a7693b0d4889db4"
)

# --- Model Loading ---

# Tokenizer for the base model (also reused by the streaming generator below).
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

# Base causal LM in bfloat16 to roughly halve memory versus fp32.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)

# Prefer GPU when available; all downstream tensors use this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model.to(device)

# Attach the PEFT adapter stored in the checkpoint subfolder of the same repo.
model = PeftModel.from_pretrained(
    base_model, BASE_MODEL_PATH, subfolder=ADAPTER_SUBFOLDER
)
model.eval()  # inference mode: disables dropout etc.

# --- Web Search Function ---
def search(query):
    """Query the Serper Google-search API and return the organic results.

    Args:
        query: Free-text search query.

    Returns:
        A list of organic-result dicts (possibly empty). Returns [] on any
        network/HTTP failure instead of raising, so callers degrade
        gracefully to a plain (non-grounded) response.
    """
    url = "https://google.serper.dev/search"
    payload = json.dumps({"q": query})
    headers = {
        'X-API-KEY': SERPER_API_KEY,
        'Content-Type': 'application/json'
    }
    try:
        # timeout= prevents a dead/slow API from hanging the Gradio worker
        # forever (requests has no default timeout).
        response = requests.post(url, headers=headers, data=payload, timeout=10)
        response.raise_for_status()
        results = response.json()
        return results.get('organic', [])
    except requests.exceptions.RequestException as e:
        print(f"Error during web search: {e}")
        return []

# --- Response Generation ---
def generate_response(message, history):
    """Stream a model reply for *message*, optionally grounded in web search.

    A message beginning with "search for " (case-insensitive) triggers a
    Serper lookup; the top snippets then replace the chat history as prompt
    context. Yields the accumulated response text incrementally so Gradio
    can render it as it streams.
    """
    # Detect the explicit search command by prefix.
    results = None
    if message.lower().startswith("search for "):
        results = search(message[len("search for "):])

    if results:
        # Grounded prompt: top-5 snippets only, no chat history.
        snippets = " ".join(r.get('snippet', '') for r in results[:5])
        prompt = f"Based on the following search results: {snippets}\n\nUser: {message}\nAssistant:"
    else:
        # Plain chat prompt rebuilt from the full (user, assistant) history.
        turns = [f"User: {u}\nAssistant: {a}\n" for u, a in history]
        prompt = "".join(turns) + f"User: {message}\nAssistant:"

    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation on a worker thread so this generator can consume the
    # streamer concurrently; the streamer ends when generate() finishes.
    worker = Thread(
        target=model.generate,
        kwargs=dict(encoded, streamer=streamer, max_new_tokens=512),
    )
    worker.start()

    partial = ""
    for piece in streamer:
        partial += piece
        yield partial


# --- Gradio UI ---
# Chat UI: generate_response is a generator, so Gradio streams tokens
# into the chatbot as they arrive.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as demo:
    gr.Markdown("# Zenith")
    gr.ChatInterface(
        generate_response,
        chatbot=gr.Chatbot(
            height=600,
            # (user_avatar, bot_avatar); None keeps the default user icon.
            avatar_images=(None, "https://i.imgur.com/9kAC4pG.png"),
            bubble_full_width=False,
        ),
        textbox=gr.Textbox(
            placeholder="Ask me anything or type 'search for <your query>'...",
            container=False,
            scale=7,
        ),
        # NOTE(review): `theme=` here looks redundant with the Blocks-level
        # theme and may be ignored or deprecated depending on the installed
        # Gradio version — confirm against the pinned gradio release.
        theme="soft",
        title=None,
        submit_btn="Send",
    )

if __name__ == "__main__":
    demo.launch()