Spaces:
OrbitMC
/
Configuration error

OrbitMC commited on
Commit
8732d46
·
verified ·
1 Parent(s): 22f16dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -28
app.py CHANGED
@@ -1,55 +1,105 @@
1
  import os
2
  import torch
3
  import gradio as gr
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
- from duckduckgo_search import DDGS
 
6
 
7
  # --- CONFIG ---
8
  MODEL_ID = "google/gemma-3-270m-it"
9
  HF_TOKEN = os.getenv('HF_TOKEN')
10
 
11
- print("--- LOADING GEMMA 3 WITH SEARCH ABILITIES ---")
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
13
  model = AutoModelForCausalLM.from_pretrained(
14
- MODEL_ID, device_map="cpu", dtype="auto", low_cpu_mem_usage=True, trust_remote_code=True, token=HF_TOKEN
 
 
 
 
 
15
  )
16
 
17
- def get_web_context(query):
18
- """Fetch the top 3 search results from DuckDuckGo."""
19
  results = []
20
  try:
21
  with DDGS() as ddgs:
22
  for r in ddgs.text(query, max_results=3):
23
  results.append(f"Source: {r['href']}\nContent: {r['body']}")
24
  except Exception as e:
25
- print(f"Search error: {e}")
26
  return "\n\n".join(results)
27
 
28
- def chat_with_search(message, history):
29
- # 1. Get real-time info
30
- print(f"Searching the web for: {message}")
31
- web_data = get_web_context(message)
 
32
 
33
- # 2. Construct a 'RAG' prompt (Retrieval-Augmented Generation)
34
- prompt = f"""
35
- You are an AI assistant with web access.
36
- Use the following search results to answer the user's question accurately.
37
 
38
- SEARCH RESULTS:
39
- {web_data}
40
 
41
- USER QUESTION: {message}
42
- ANSWER:"""
 
 
 
 
 
 
43
 
44
- inputs = tokenizer(prompt, return_tensors="pt")
45
- with torch.no_grad():
46
- outputs = model.generate(**inputs, max_new_tokens=2048)
 
 
 
 
 
 
 
 
 
 
47
 
48
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
49
- # Extract only the final answer
50
- return response.split("ANSWER:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # Launching
53
- demo = gr.ChatInterface(fn=chat_with_search)
54
  if __name__ == "__main__":
55
- demo.launch(server_name="0.0.0.0", share=True)
 
1
  import os
2
  import torch
3
  import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
5
+ from ddgs import DDGS # Updated package name
6
+ from threading import Thread
7
 
8
  # --- CONFIG ---
9
  MODEL_ID = "google/gemma-3-270m-it"
10
  HF_TOKEN = os.getenv('HF_TOKEN')
11
 
12
+ # --- MODEL LOADING ---
13
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
14
  model = AutoModelForCausalLM.from_pretrained(
15
+ MODEL_ID,
16
+ device_map="cpu",
17
+ torch_dtype=torch.float32, # CPU is more stable with float32
18
+ low_cpu_mem_usage=True,
19
+ trust_remote_code=True,
20
+ token=HF_TOKEN
21
  )
22
 
23
+ def web_search(query):
 
24
  results = []
25
  try:
26
  with DDGS() as ddgs:
27
  for r in ddgs.text(query, max_results=3):
28
  results.append(f"Source: {r['href']}\nContent: {r['body']}")
29
  except Exception as e:
30
+ return f"Search failed: {e}"
31
  return "\n\n".join(results)
32
 
33
+ def stream_chat(message, history, search_enabled, max_tokens, temperature):
34
+ # 1. Handle Web Search
35
+ context = ""
36
+ if search_enabled:
37
+ context = f"\n\nWEB SEARCH RESULTS:\n{web_search(message)}"
38
 
39
+ # 2. Prepare Prompt
40
+ full_prompt = f"Context: {context}\n\nUser: {message}\nAssistant:"
41
+ inputs = tokenizer(full_prompt, return_tensors="pt").to("cpu")
 
42
 
43
+ # 3. Setup Streamer (This fixes the "Freezing" issue)
44
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
45
 
46
+ generate_kwargs = dict(
47
+ **inputs,
48
+ streamer=streamer,
49
+ max_new_tokens=max_tokens,
50
+ do_sample=True,
51
+ temperature=temperature,
52
+ top_p=0.9,
53
+ )
54
 
55
+ # Run generation in a separate thread so UI stays responsive
56
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
57
+ t.start()
58
+
59
+ # Yield tokens one by one
60
+ partial_message = ""
61
+ for new_token in streamer:
62
+ partial_message += new_token
63
+ yield partial_message
64
+
65
+ # --- GRADIO UI ---
66
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
67
+ gr.Markdown("# 🚀 Gemma 3 Ultra Bot (CPU Optimized)")
68
 
69
+ with gr.Row():
70
+ with gr.Column(scale=4):
71
+ chatbot = gr.Chatbot(height=500)
72
+ msg = gr.Textbox(placeholder="Ask me anything...", label="Input")
73
+ with gr.Row():
74
+ submit = gr.Button("Send", variant="primary")
75
+ clear = gr.Button("Clear Chat")
76
+
77
+ with gr.Column(scale=1):
78
+ gr.Markdown("### ⚙️ Settings")
79
+ search_toggle = gr.Checkbox(label="Enable Web Search", value=False)
80
+ token_slider = gr.Slider(minimum=64, maximum=1024, value=256, step=64, label="Max New Tokens")
81
+ temp_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
82
+ gr.Markdown("---")
83
+ gr.Info("Note: Generation on CPU may take 10-30 seconds. Streaming is enabled to show progress.")
84
+
85
+ # Link components
86
+ def user(user_message, history):
87
+ return "", history + [[user_message, None]]
88
+
89
+ def bot(history, search_on, tokens, temp):
90
+ user_message = history[-1][0]
91
+ history[-1][1] = ""
92
+ for character in stream_chat(user_message, history, search_on, tokens, temp):
93
+ history[-1][1] = character
94
+ yield history
95
+
96
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
97
+ bot, [chatbot, search_toggle, token_slider, temp_slider], chatbot
98
+ )
99
+ submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
100
+ bot, [chatbot, search_toggle, token_slider, temp_slider], chatbot
101
+ )
102
+ clear.click(lambda: None, None, chatbot, queue=False)
103
 
 
 
104
  if __name__ == "__main__":
105
+ demo.queue().launch(server_name="0.0.0.0")