# Zenith: Gradio chat app for a PEFT-adapted causal LM with optional Serper web search.
import json
import os
from threading import Thread

import gradio as gr
import requests
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# --- Configuration ---
BASE_MODEL_PATH = "algorythmtechnologies/zenith_coder_v1.1"
ADAPTER_SUBFOLDER = "checkpoint-300"
# SECURITY: never commit API keys to source. Read the key from the environment
# (e.g. a Space secret named SERPER_API_KEY); empty string disables search auth.
SERPER_API_KEY = os.environ.get("SERPER_API_KEY", "")

# --- Model Loading ---
# Load the tokenizer for the base model.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

# Load the base model in bfloat16 with low-CPU-memory staging.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)

# Prefer GPU when available; bfloat16 on CPU is slow but functional.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model.to(device)

# Attach the PEFT (LoRA) adapter stored in a subfolder of the same Hub repo,
# then switch to inference mode.
model = PeftModel.from_pretrained(base_model, BASE_MODEL_PATH, subfolder=ADAPTER_SUBFOLDER)
model.eval()
# --- Web Search Function ---
def search(query):
    """Query the Serper API and return the organic results.

    Args:
        query: Search terms to send to Serper.

    Returns:
        A list of result dicts (Serper's "organic" entries), or an empty
        list when the request or response parsing fails.
    """
    url = "https://google.serper.dev/search"
    payload = json.dumps({"q": query})
    headers = {
        'X-API-KEY': SERPER_API_KEY,
        'Content-Type': 'application/json'
    }
    try:
        # Time-box the request so a slow or unreachable API cannot hang
        # the chat UI indefinitely (the original call had no timeout).
        response = requests.post(url, headers=headers, data=payload, timeout=10)
        response.raise_for_status()
        results = response.json()
        return results.get('organic', [])
    except (requests.exceptions.RequestException, ValueError) as e:
        # Best-effort: log and degrade to "no results" rather than crash.
        # ValueError covers a non-JSON body from response.json().
        print(f"Error during web search: {e}")
        return []
# --- Response Generation ---
def generate_response(message, history):
    """Stream a model reply for *message*, with optional web-search context.

    Messages starting with "search for " trigger a Serper lookup; the top
    snippets then replace the chat history as prompt context.

    Args:
        message: The latest user message.
        history: Prior turns. NOTE(review): assumes tuple-style history
            [(user, assistant), ...] — confirm against the installed
            Gradio ChatInterface message format.

    Yields:
        The accumulated response text, so Gradio renders it incrementally.
    """
    # Rebuild a plain-text transcript of the conversation so far.
    full_prompt = ""
    for user_msg, assistant_msg in history:
        full_prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
    full_prompt += f"User: {message}\nAssistant:"

    # Optional web-augmented mode: "search for <query>".
    if message.lower().startswith("search for "):
        search_query = message[len("search for "):]
        search_results = search(search_query)
        if search_results:
            # Use the top-5 snippets as context. This intentionally replaces
            # the chat history so the model focuses on the search results.
            context = " ".join(res.get('snippet', '') for res in search_results[:5])
            full_prompt = f"Based on the following search results: {context}\n\nUser: {message}\nAssistant:"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # BatchEncoding is a mapping, so dict(inputs, ...) merges in the kwargs.
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)

    # Run generation on a worker thread so tokens can be streamed as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    # Streamer exhaustion means generation ended; join so the worker thread
    # is not left dangling per request (the original never joined it).
    thread.join()
# --- Gradio UI ---
# --- Gradio UI ---
# Build the chat page: a themed Blocks container wrapping a ChatInterface
# that streams replies from generate_response.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as demo:
    gr.Markdown("# Zenith")
    # NOTE(review): theme="soft" on ChatInterface is presumably ignored inside
    # a Blocks context (the Blocks-level theme above takes effect) — confirm
    # against the installed Gradio version.
    gr.ChatInterface(
        generate_response,
        chatbot=gr.Chatbot(
            height=600,
            # (user avatar, bot avatar): user keeps the default icon.
            avatar_images=(None, "https://i.imgur.com/9kAC4pG.png"),
            bubble_full_width=False,
        ),
        textbox=gr.Textbox(
            placeholder="Ask me anything or type 'search for <your query>'...",
            container=False,
            scale=7,
        ),
        theme="soft",
        title=None,
        submit_btn="Send",
    )

# Launch the web server when run as a script.
if __name__ == "__main__":
    demo.launch()