# Source: Hugging Face Space file-view header (non-code artifact, commented out
# so this file remains valid Python):
#   algorythmtechnologies — "Upload app.py with huggingface_hub"
#   revision d9a7e49 (verified) · raw · history · blame · 3.4 kB
import json
import os
from threading import Thread

import gradio as gr
import requests
import torch
from peft import PeftModel
from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM
# --- Configuration ---
# Hub repository holding both the base model and the PEFT adapter checkpoints.
BASE_MODEL_PATH = "algorythmtechnologies/zenith_coder_v1.1"
# Subfolder within the repo containing the adapter weights to load.
ADAPTER_SUBFOLDER = "checkpoint-300"
# SECURITY: never commit API keys to source. Read the Serper key from the
# environment (set it as a Space/host secret). The previously hard-coded key
# is exposed in this file's history and should be rotated. An empty value
# simply makes search requests fail and return [] (handled in search()).
SERPER_API_KEY = os.environ.get("SERPER_API_KEY", "")
# --- Model Loading ---
# Load the tokenizer for the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
# Load the base causal LM in bfloat16. low_cpu_mem_usage streams weights to
# keep peak RAM down during load. NOTE(review): trust_remote_code executes
# modeling code shipped with the checkpoint — acceptable only because the
# repo is the project's own; confirm before pointing this at another repo.
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_PATH,
trust_remote_code=True,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16,
)
# Place the model on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model.to(device)
# Wrap the (already device-placed) base model with the PEFT/LoRA adapter
# stored under ADAPTER_SUBFOLDER in the same Hub repository, then switch to
# eval mode (disables dropout) for inference.
model = PeftModel.from_pretrained(base_model, BASE_MODEL_PATH, subfolder=ADAPTER_SUBFOLDER)
model.eval()
# --- Web Search Function ---
def search(query):
    """Perform a web search via the Serper API.

    Args:
        query: The search query string.

    Returns:
        A list of "organic" result dicts from Serper, or an empty list when
        the request fails, times out, or returns no organic results.
    """
    url = "https://google.serper.dev/search"
    headers = {
        'X-API-KEY': SERPER_API_KEY,
        'Content-Type': 'application/json'
    }
    try:
        # json= serializes the payload for us; the timeout prevents a slow
        # or unreachable API from hanging the app indefinitely (the original
        # call had no timeout, which requests never applies by default).
        response = requests.post(url, headers=headers, json={"q": query}, timeout=10)
        response.raise_for_status()
        results = response.json()
        return results.get('organic', [])
    except requests.exceptions.RequestException as e:
        # Best-effort: search failures degrade to "no results" rather than
        # crashing the chat handler.
        print(f"Error during web search: {e}")
        return []
# --- Response Generation ---
def generate_response(message, history):
    """Stream a model reply for *message*, optionally grounded in web search.

    Yields the accumulated response text after each new token so Gradio can
    render it incrementally. A message starting with "search for " triggers a
    Serper lookup; when results come back, the prompt is rebuilt around the
    top snippets (and the chat history is intentionally dropped).
    """
    # Rebuild the running transcript as a plain-text prompt.
    turns = [f"User: {u}\nAssistant: {a}\n" for u, a in history]
    prompt = "".join(turns) + f"User: {message}\nAssistant:"

    if message.lower().startswith("search for "):
        hits = search(message[len("search for "):])
        if hits:
            # Ground the answer in up to five result snippets.
            context = " ".join([hit.get('snippet', '') for hit in hits[:5]])
            prompt = f"Based on the following search results: {context}\n\nUser: {message}\nAssistant:"

    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Generation runs on a worker thread so this generator can yield tokens
    # from the streamer as they are produced.
    worker = Thread(target=model.generate, kwargs=dict(encoded, streamer=streamer, max_new_tokens=512))
    worker.start()

    so_far = ""
    for piece in streamer:
        so_far += piece
        yield so_far
# --- Gradio UI ---
# Build the chat page: a Soft-themed Blocks layout hosting a ChatInterface
# wired to generate_response (a generator, so replies stream token by token).
with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as demo:
    gr.Markdown("# Zenith")
    gr.ChatInterface(
        generate_response,
        chatbot=gr.Chatbot(
            height=600,
            # (user_avatar, assistant_avatar); None keeps the default user avatar.
            avatar_images=(None, "https://i.imgur.com/9kAC4pG.png"),
            bubble_full_width=False,
        ),
        textbox=gr.Textbox(
            placeholder="Ask me anything or type 'search for <your query>'...",
            container=False,
            scale=7,
        ),
        # NOTE(review): this theme kwarg is likely ignored because the outer
        # gr.Blocks already sets a theme — confirm against the installed
        # gradio version; several of these kwargs are version-sensitive.
        theme="soft",
        title=None,
        submit_btn="Send",
    )

# Launch only when executed as a script (not when imported).
if __name__ == "__main__":
    demo.launch()