algorythmtechnologies committed on
Commit
d9a7e49
·
verified ·
1 Parent(s): 7b33f4a

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +130 -0
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os
from threading import Thread

import gradio as gr
import requests
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
+
9
+
10
+
11
# --- Configuration ---

# Hub repository that the tokenizer, the base model and the adapter are all
# loaded from.
BASE_MODEL_PATH = "algorythmtechnologies/zenith_coder_v1.1"

# Subfolder of BASE_MODEL_PATH that holds the PEFT adapter checkpoint.
ADAPTER_SUBFOLDER = "checkpoint-300"

# Serper API key, read from the environment so the secret is never committed
# alongside the code (for example, configure it as a secret of the hosting
# environment). Falls back to an empty string, in which case web searches
# are expected to be rejected by the API and handled by the caller.
# NOTE(review): the key that used to be hard-coded on this line was published
# with the repository and should be revoked/rotated.
SERPER_API_KEY = os.environ.get("SERPER_API_KEY", "")
18
+
19
+
20
+
21
# --- Model Loading ---

# Tokenizer is fetched from the same repository as the base weights.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

# Base causal language model, requested in bfloat16 with the reduced
# CPU-memory loading path; the repository's custom model code is trusted.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)

# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
base_model.to(device)

# Wrap the base model with the PEFT adapter stored under ADAPTER_SUBFOLDER of
# the same Hub repository, then switch to evaluation mode.
model = PeftModel.from_pretrained(
    base_model,
    BASE_MODEL_PATH,
    subfolder=ADAPTER_SUBFOLDER,
)
model.eval()
58
+
59
+ # --- Web Search Function ---
60
def search(query, timeout=10.0):
    """Perform a web search using the Serper API.

    Args:
        query: Search string, sent as the ``q`` field of the JSON request body.
        timeout: Seconds to wait for the server before the request is
            abandoned.

    Returns:
        The value stored under the ``organic`` key of the JSON response, or
        an empty list when the request fails, the response body is not valid
        JSON, the decoded body is not a JSON object, or it has no ``organic``
        entry.
    """
    url = "https://google.serper.dev/search"
    headers = {
        "X-API-KEY": SERPER_API_KEY,
        "Content-Type": "application/json",
    }
    try:
        # An explicit timeout keeps a stalled connection from blocking the
        # chat handler forever; requests applies no timeout by default.
        response = requests.post(
            url,
            headers=headers,
            json={"q": query},
            timeout=timeout,
        )
        response.raise_for_status()
        results = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures on requests versions where
        # the decode error is not a RequestException subclass.
        print(f"Error during web search: {e}")
        return []

    # Best-effort: anything other than a JSON object carries no results.
    if not isinstance(results, dict):
        return []
    return results.get("organic", [])
76
+
77
+ # --- Response Generation ---
78
def _history_to_prompt(history):
    """Render earlier chat turns as ``User:``/``Assistant:`` transcript text."""
    lines = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Message-style history: one dict per message.
            # NOTE(review): assumes "role"/"content" keys as used by Gradio's
            # messages format — confirm against the installed Gradio version.
            speaker = {"user": "User", "assistant": "Assistant"}.get(turn.get("role"))
            if speaker is None:
                continue
            lines.append(f"{speaker}: {turn.get('content', '')}\n")
        else:
            # Pair-style history: (user_message, assistant_message).
            user_msg, assistant_msg = turn
            lines.append(f"User: {user_msg}\nAssistant: {assistant_msg}\n")
    return "".join(lines)


def generate_response(message, history):
    """Stream a model reply to ``message``, optionally grounded in a web search.

    When ``message`` starts with ``"search for "`` (case-insensitive), the
    remainder is sent to ``search`` and the snippets of the first five results
    are placed in front of the prompt. The conversation history is kept in the
    prompt in both cases.

    Args:
        message: The latest user message.
        history: Earlier turns, either as ``(user, assistant)`` pairs or as
            role/content dicts.

    Yields:
        The reply generated so far, growing as new text is streamed from the
        model.
    """
    full_prompt = f"{_history_to_prompt(history)}User: {message}\nAssistant:"

    prefix = "search for "
    if message.lower().startswith(prefix):
        search_query = message[len(prefix):].strip()
        # An empty query cannot return anything useful; skip the API call.
        search_results = search(search_query) if search_query else []
        if search_results:
            context = " ".join(res.get("snippet", "") for res in search_results[:5])
            # Prepend the search context instead of replacing the prompt so
            # earlier turns of the conversation are not lost.
            full_prompt = f"Based on the following search results: {context}\n\n{full_prompt}"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)

    # Generation runs in a worker thread so that this generator can consume
    # the streamer and yield partial text while tokens are still produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

    # The streamer is exhausted once generation ends; reap the worker thread.
    thread.join()
107
+
108
+
109
# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as demo:
    gr.Markdown("# Zenith")

    # Chat history pane; only the second avatar slot is given an image.
    chat_window = gr.Chatbot(
        height=600,
        avatar_images=(None, "https://i.imgur.com/9kAC4pG.png"),
        bubble_full_width=False,
    )

    # Input box whose placeholder advertises the "search for" prefix that
    # generate_response checks for.
    prompt_box = gr.Textbox(
        placeholder="Ask me anything or type 'search for <your query>'...",
        container=False,
        scale=7,
    )

    # Chat interface driven by the streaming generate_response generator.
    gr.ChatInterface(
        generate_response,
        chatbot=chat_window,
        textbox=prompt_box,
        theme="soft",
        title=None,
        submit_btn="Send",
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()