Spaces: Runtime error
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from googlesearch import search | |
| import requests | |
| from bs4 import BeautifulSoup | |
# Load the model and tokenizer once at import time so every request reuses them.
model_name = "akjindal53244/Llama-3.1-Storm-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # half the memory of fp32; assumes bf16-capable hardware — TODO confirm
    device_map="auto"  # let accelerate shard/place layers across available devices
)
def fetch_web_content(url):
    """Download a page and return the concatenated text of its <p> tags.

    Args:
        url: The URL to fetch.

    Returns:
        A space-joined string of all paragraph text on the page, or a
        fallback message when the page cannot be fetched or parsed.
    """
    try:
        response = requests.get(url, timeout=10)
        # Fail fast on HTTP 4xx/5xx instead of scraping an error page as if
        # it were real content (the original parsed whatever came back).
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')
        return ' '.join(p.get_text() for p in soup.find_all('p'))
    except Exception as e:
        # Best-effort by design: one bad search result must not crash generation.
        print(f"Error fetching {url}: {str(e)}")
        return "Could not fetch content from this URL"
def web_search(query, num_results=3):
    """Search the web for *query* and scrape the top hits.

    Args:
        query: Search string.
        num_results: How many results to request (default 3).

    Returns:
        A list of dicts with keys ``title``, ``url`` and ``content``
        (page text truncated to 1000 characters); empty list on failure.
    """
    gathered = []
    try:
        for hit in search(query, num_results=num_results, advanced=True):
            page_text = fetch_web_content(hit.url)
            gathered.append({
                "title": hit.title,
                "url": hit.url,
                "content": page_text[:1000],  # Limit content length
            })
    except Exception as e:
        print(f"Search error: {str(e)}")
        return []
    return gathered
@spaces.GPU  # Required on ZeroGPU Spaces: without this the app has no GPU and crashes.
def generate_text(prompt, max_length, temperature, use_web):
    """Generate a chat response, optionally grounded in live web results.

    Args:
        prompt: The user's query.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        use_web: When True, prepend scraped search results as context.

    Returns:
        The decoded model completion (prompt tokens stripped).
    """
    if use_web:
        search_results = web_search(prompt)
        context = "\n".join([f"Source: {res['url']}\nContent: {res['content']}" for res in search_results])
        prompt = f"Web Context:\n{context}\n\nUser Query: {prompt}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant with web search capabilities."},
        {"role": "user", "content": prompt}
    ]
    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    # Inference only: disable autograd bookkeeping to cut peak memory use.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temperature,
            top_k=100,
            top_p=0.95,
        )
    # Slice off the prompt tokens so only the new completion is returned.
    return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
# CSS and UI components
# Dark theme for the Gradio app; the custom properties in :root are reused
# by the selectors below and by elem_classes set on components in the UI.
css = """
:root {
    --primary: #e94560;
    --secondary: #1a1a2e;
    --background: #16213e;
    --text: #e0e0e0;
}
body {
    background-color: var(--background);
    color: var(--text);
    font-family: 'Inter', sans-serif;
}
.container {
    max-width: 1200px;
    margin: auto;
    padding: 20px;
}
.gradio-container {
    background-color: var(--background);
    border-radius: 15px;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3);
}
.header {
    background: linear-gradient(135deg, #0f3460 0%, #1a1a2e 100%);
    padding: 2rem;
    border-radius: 15px 15px 0 0;
    text-align: center;
    margin-bottom: 2rem;
}
.header h1 {
    color: var(--primary);
    font-size: 2.8rem;
    margin-bottom: 1rem;
    font-weight: 700;
}
.input-group, .output-group {
    background-color: var(--secondary);
    padding: 2rem;
    border-radius: 12px;
    margin-bottom: 2rem;
    border: 1px solid #2d2d4d;
}
.generate-btn {
    background: linear-gradient(135deg, var(--primary) 0%, #c81e45 100%) !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 12px 28px !important;
}
.example-prompts ul {
    grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
    gap: 1rem;
}
"""
# Canned queries rendered as one-click example buttons in the Chat tab.
example_prompts = [
    "Explain quantum computing in simple terms",
    "Latest developments in AI research",
    "Compare React and Vue.js frameworks",
    "Recent advancements in cancer treatment"
]
# Build the Gradio UI and wire event handlers, then launch the app.
with gr.Blocks(css=css, theme=gr.themes.Default()) as iface:
    gr.HTML("""
    <div class="header">
        <h1>Llama-3.1-Storm-8B AI Assistant</h1>
        <p>Enhanced with real-time web search capabilities</p>
    </div>
    """)
    with gr.Tabs():
        with gr.TabItem("Chat Assistant"):
            with gr.Row():
                with gr.Column(scale=3):
                    with gr.Group(elem_classes="example-prompts"):
                        gr.Markdown("## Example Queries")
                        example_btns = [gr.Button(prompt) for prompt in example_prompts]
                    with gr.Group(elem_classes="input-group"):
                        prompt = gr.Textbox(label="Your Query", placeholder="Enter your question...", lines=5)
                        with gr.Row():
                            web_search_toggle = gr.Checkbox(label="Enable Web Search", value=False)
                            num_results = gr.Slider(1, 5, value=3, step=1, label="Search Results")
                        with gr.Row():
                            max_length = gr.Slider(32, 1024, value=256, step=32, label="Response Length")
                            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Creativity")
                        generate_btn = gr.Button("Generate Response", elem_classes="generate-btn")
                with gr.Column(scale=2):
                    with gr.Group(elem_classes="output-group"):
                        output = gr.Textbox(label="Generated Response", lines=12)
                        with gr.Row():
                            copy_btn = gr.Button("Copy")
                            clear_btn = gr.Button("Clear")
        with gr.TabItem("Web Results"):
            web_results = gr.JSON(label="Search Results Preview")
    # Event handlers
    generate_btn.click(
        generate_text,
        inputs=[prompt, max_length, temperature, web_search_toggle],
        outputs=output
    ).then(
        # Wire the "Search Results" slider through (it was previously unused).
        lambda q, n: web_search(q, num_results=int(n)) if q else [],
        inputs=[prompt, num_results],
        outputs=web_results
    )
    for btn in example_btns:
        # The button component as an input yields its label, filling the prompt box.
        btn.click(lambda x: x, inputs=[btn], outputs=[prompt])
    copy_btn.click(
        None,
        inputs=[output],
        # Gradio 4.x renamed the client-side hook from `_js` to `js`;
        # the old keyword raises a TypeError at startup.
        js="(text) => { navigator.clipboard.writeText(text); return [] }"
    )
    clear_btn.click(lambda: "", outputs=[output])
iface.launch()