"""Zenith: a Gradio chat UI over a PEFT-adapted causal LM with optional web search.

Messages beginning with "search for " trigger a Serper.dev web search whose top
snippets are injected into the prompt as context. All other messages are answered
from the rolling chat history. Responses are streamed token-by-token.
"""

import json
import os
from threading import Thread

import gradio as gr
import requests
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# --- Configuration ---
BASE_MODEL_PATH = "algorythmtechnologies/zenith_coder_v1.1"
ADAPTER_SUBFOLDER = "checkpoint-300"
# SECURITY: the API key was previously hardcoded here and committed to source.
# That key is compromised and must be rotated at serper.dev. Supply the new key
# via the SERPER_API_KEY environment variable; an empty key makes search fail
# gracefully (the Serper API returns an auth error, handled in search()).
SERPER_API_KEY = os.environ.get("SERPER_API_KEY", "")

# --- Model Loading ---
# Tokenizer and base model come from the same Hub repo; the LoRA adapter lives
# in a subfolder of that repo.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)

# Move model to the best available device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model.to(device)

# Attach the PEFT adapter from the checkpoint subfolder and freeze for inference.
model = PeftModel.from_pretrained(base_model, BASE_MODEL_PATH, subfolder=ADAPTER_SUBFOLDER)
model.eval()


# --- Web Search Function ---
def search(query):
    """Perform a web search via the Serper API.

    Args:
        query: Free-text search query.

    Returns:
        The list of organic result dicts (each typically carrying 'title',
        'link', 'snippet'), or an empty list on any request failure.
    """
    url = "https://google.serper.dev/search"
    headers = {
        "X-API-KEY": SERPER_API_KEY,
        "Content-Type": "application/json",
    }
    try:
        # timeout prevents a stalled Serper endpoint from hanging the chat
        # handler indefinitely (the original call had no timeout).
        response = requests.post(url, headers=headers, json={"q": query}, timeout=10)
        response.raise_for_status()
        results = response.json()
        return results.get("organic", [])
    except requests.exceptions.RequestException as e:
        # Best-effort: search failures degrade to a plain (context-free) answer.
        print(f"Error during web search: {e}")
        return []


# --- Response Generation ---
def generate_response(message, history):
    """Stream a model response for *message*, with optional web-search context.

    Args:
        message: The latest user message. If it starts with "search for "
            (case-insensitive), the remainder is used as a web-search query.
        history: List of (user_msg, assistant_msg) tuple pairs from Gradio.

    Yields:
        The cumulative generated text, growing as tokens stream in.
    """
    # Build a plain-text transcript prompt from the chat history.
    full_prompt = ""
    for user_msg, assistant_msg in history:
        full_prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
    full_prompt += f"User: {message}\nAssistant:"

    search_results = None
    if message.lower().startswith("search for "):
        search_query = message[len("search for "):]
        search_results = search(search_query)
        if search_results:
            # Concatenate the top-5 snippets as grounding context.
            # NOTE(review): this intentionally replaces the history-based
            # prompt, so search turns do not see prior conversation.
            context = " ".join([res.get("snippet", "") for res in search_results[:5]])
            full_prompt = (
                f"Based on the following search results: {context}\n\n"
                f"User: {message}\nAssistant:"
            )

    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    # Stream tokens from a background generation thread so the UI updates live.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text


# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as demo:
    gr.Markdown("# Zenith")
    gr.ChatInterface(
        generate_response,
        chatbot=gr.Chatbot(
            height=600,
            avatar_images=(None, "https://i.imgur.com/9kAC4pG.png"),
            bubble_full_width=False,
        ),
        textbox=gr.Textbox(
            placeholder="Ask me anything or type 'search for '...",
            container=False,
            scale=7,
        ),
        theme="soft",
        title=None,
        submit_btn="Send",
    )

if __name__ == "__main__":
    demo.launch()