Hugging Face Spaces app (last recorded build status: Runtime error).
import json
import os
from threading import Thread

import gradio as gr
import requests
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# --- Configuration ---
# Hub repository holding both the base model and the PEFT adapter checkpoints.
BASE_MODEL_PATH = "algorythmtechnologies/zenith_coder_v1.1"
# Subfolder inside the repo containing the adapter weights to load.
ADAPTER_SUBFOLDER = "checkpoint-300"
# SECURITY FIX: the original committed a live Serper API key to source.
# Read it from the environment instead (set SERPER_API_KEY as a Space
# secret); the empty-string fallback keeps the module importable, and
# search() will simply fail with an auth error until the key is set.
SERPER_API_KEY = os.environ.get("SERPER_API_KEY", "")
# --- Model Loading ---
# Runs once at import time; downloads weights from the Hub on first start.
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
# Load the model
# trust_remote_code=True executes custom modeling code shipped with the
# repo — acceptable only because the repo is first-party.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    # NOTE(review): bfloat16 on a CPU-only Space can be slow or unsupported
    # depending on hardware — confirm the runtime has bf16 support.
    torch_dtype=torch.bfloat16,
)
# Move model to appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model.to(device)
# Load the PEFT adapter from the subfolder in the Hub repository
model = PeftModel.from_pretrained(base_model, BASE_MODEL_PATH, subfolder=ADAPTER_SUBFOLDER)
# Inference mode: disables dropout and other train-time behavior.
model.eval()
# --- Web Search Function --- |
def search(query):
    """Perform a Google web search via the Serper API.

    Args:
        query: Free-text search query string.

    Returns:
        The list of "organic" result dicts from Serper (possibly empty),
        or [] on any request/HTTP error.
    """
    url = "https://google.serper.dev/search"
    payload = json.dumps({"q": query})
    headers = {
        'X-API-KEY': SERPER_API_KEY,
        'Content-Type': 'application/json'
    }
    try:
        # FIX: the original call had no timeout, so a stalled connection
        # would hang the request (and the UI) forever. 10 s bounds the wait
        # and raises requests.exceptions.Timeout, caught below.
        response = requests.post(url, headers=headers, data=payload, timeout=10)
        response.raise_for_status()
        results = response.json()
        return results.get('organic', [])
    except requests.exceptions.RequestException as e:
        # Best-effort: log and degrade to "no results" rather than crash.
        print(f"Error during web search: {e}")
        return []
# --- Response Generation --- |
def generate_response(message, history):
    """Stream a model response for *message*, optionally grounded in search.

    Args:
        message: The latest user message. If it starts with "search for ",
            the remainder is sent to the Serper API and the top-5 result
            snippets are injected into the prompt as context.
        history: List of (user_message, assistant_message) pairs from
            previous turns, as provided by gr.ChatInterface.

    Yields:
        The cumulative generated text, growing as tokens stream in.
    """
    # Rebuild the conversation as a plain-text transcript prompt.
    full_prompt = ""
    for user_msg, assistant_msg in history:
        full_prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
    full_prompt += f"User: {message}\nAssistant:"

    # Optional web-search grounding, triggered by the "search for " prefix.
    if message.lower().startswith("search for "):
        search_query = message[len("search for "):]
        search_results = search(search_query)
        if search_results:
            # NOTE: the grounded prompt replaces the chat transcript with
            # the snippet context, matching the original behavior.
            context = " ".join(res.get('snippet', '') for res in search_results[:5])
            full_prompt = f"Based on the following search results: {context}\n\nUser: {message}\nAssistant:"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    # skip_prompt=True so the streamer yields only newly generated tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)

    # Run generation on a worker thread so this generator can consume the
    # streamer concurrently and yield partial text to the UI.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    try:
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
    finally:
        # FIX: the original never joined the worker thread, leaking it if
        # the consumer abandons the generator mid-stream and silently
        # discarding any exception raised inside model.generate.
        thread.join()
# --- Gradio UI ---
# Top-level UI wiring; `demo` is the Blocks app launched below.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as demo:
    gr.Markdown("# Zenith")
    gr.ChatInterface(
        generate_response,  # streaming generator: yields cumulative text
        chatbot=gr.Chatbot(
            height=600,
            # (user avatar, assistant avatar); None keeps the default user icon.
            avatar_images=(None, "https://i.imgur.com/9kAC4pG.png"),
            bubble_full_width=False,
        ),
        textbox=gr.Textbox(
            placeholder="Ask me anything or type 'search for <your query>'...",
            container=False,
            scale=7,
        ),
        # NOTE(review): this `theme=` duplicates the Soft theme set on
        # gr.Blocks above and may be ignored when ChatInterface is nested
        # inside Blocks — confirm against the installed Gradio version.
        theme="soft",
        title=None,
        submit_btn="Send",
    )
# Standard script entry guard: launch the app only when run directly.
if __name__ == "__main__":
    demo.launch()