# Hugging Face Spaces status: Sleeping
import logging
import os
import urllib
import urllib.parse

import feedparser
import gradio as gr
import requests
import torch
from bs4 import BeautifulSoup
from transformers import AutoModelForCausalLM, AutoTokenizer
# Set up logging: configure the root logger at DEBUG level and create this
# module's named logger, which every function below writes to.
logging.basicConfig(level="DEBUG")
logger = logging.getLogger(name=__name__)
| # Define device and load model and tokenizer | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct" | |
# Load model and tokenizer once at import time. A failure is logged and both
# names are set to None instead of raising, so the UI can still start and the
# inference path can report "model not available".
try:
    logger.debug("Attempting to load the model and tokenizer")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    logger.debug("Model and tokenizer loaded successfully")
except Exception as e:  # deliberate best-effort boundary: keep the app running
    # logger.exception records the full traceback (download, auth, OOM, ...),
    # which a plain logger.error call would have dropped.
    logger.exception("Error loading model and tokenizer: %s", e)
    model = None
    tokenizer = None
# Function to fetch news from Google News RSS feed
def fetch_news(term, num_results=2):
    """Return up to *num_results* Google News RSS entries matching *term*.

    Args:
        term: Free-text search query; it is percent-encoded into the feed URL.
        num_results: Maximum number of entries to return. Values <= 0 yield an
            empty list (a negative value would otherwise slice from the end).

    Returns:
        A list of ``{"link": ..., "text": ...}`` dicts, where ``"text"`` is the
        entry title. A missing field comes back as ``None`` instead of raising.
    """
    logger.debug("Fetching news for term: %s", term)
    url = f"https://news.google.com/rss/search?q={urllib.parse.quote(term)}"
    feed = feedparser.parse(url)
    if feed.get("bozo") and not feed.entries:
        # feedparser reports network/parse failures through ``bozo`` rather than
        # raising, so surface them instead of silently returning nothing.
        logger.warning("Could not read news feed %s: %s", url, feed.get("bozo_exception"))
    results = [
        {"link": entry.get("link"), "text": entry.get("title")}
        for entry in feed.entries[: max(num_results, 0)]
    ]
    logger.debug("Fetched news results: %s", results)
    return results
# Function to perform a Google search and return the results
def search(term, num_results=2, lang="en", timeout=5, safe="active", ssl_verify=None):
    """Scrape Google web results for *term* and fetch the visible text of each hit.

    Args:
        term: Query string. ``requests`` URL-encodes it through ``params``.
        num_results: Maximum number of result entries to return.
        lang: Interface language sent to Google as ``hl``.
        timeout: Seconds to wait for each HTTP request. This applies to the
            search page and to every result page.
        safe: Google SafeSearch setting sent as ``safe``.
        ssl_verify: Forwarded to ``requests`` as ``verify``. It applies to both
            the search page and the result pages.

    Returns:
        A list of at most ``num_results`` dicts with keys ``"link"`` and
        ``"text"``:

        - ``"text"`` holds the page's visible text, cut to 8000 characters
          and suffixed with ``"..."`` when longer.
        - ``"text"`` is ``None`` when the page could not be fetched.
        - Both keys are ``None`` for a result block without a link.

        An error on the search page itself stops the search. Whatever was
        collected so far is still returned.
    """
    logger.debug("Starting search for term: %s", term)
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
    }
    max_chars_per_page = 8000
    all_results = []
    with requests.Session() as session:
        while len(all_results) < num_results:
            # Each result block yields exactly one entry, so the number of
            # collected entries is also the next result offset.
            start = len(all_results)
            try:
                resp = session.get(
                    url="https://www.google.com/search",
                    headers=headers,
                    params={
                        "q": term,
                        "num": num_results - start,
                        "hl": lang,
                        "start": start,
                        "safe": safe,
                    },
                    timeout=timeout,
                    verify=ssl_verify,
                )
                resp.raise_for_status()
                soup = BeautifulSoup(resp.text, "html.parser")
                result_block = soup.find_all("div", attrs={"class": "g"})
                if not result_block:
                    # Nothing parsable (end of results, consent page or a
                    # block). Re-querying the next offset would only repeat
                    # the request, so stop here.
                    break
                # Never return more than the caller asked for.
                for result in result_block[: num_results - start]:
                    anchor = result.find("a", href=True)
                    if not anchor:
                        all_results.append({"link": None, "text": None})
                        continue
                    link = anchor["href"]
                    try:
                        # Without a timeout a single slow site could hang the
                        # whole request handler indefinitely.
                        webpage = session.get(link, headers=headers, timeout=timeout, verify=ssl_verify)
                        webpage.raise_for_status()
                    except requests.exceptions.RequestException as e:
                        logger.error("Error fetching or processing %s: %s", link, e)
                        all_results.append({"link": link, "text": None})
                        continue
                    visible_text = extract_text_from_webpage(webpage.text)
                    if len(visible_text) > max_chars_per_page:
                        visible_text = visible_text[:max_chars_per_page] + "..."
                    all_results.append({"link": link, "text": visible_text})
            except Exception as e:  # boundary: return partial results instead of failing the chat turn
                logger.exception("Error during search: %s", e)
                break
    logger.debug("Search results: %s", all_results)
    return all_results
# Function to extract visible text from HTML content
def extract_text_from_webpage(html_content, separator=" "):
    """Return the human-visible text of an HTML document.

    ``script``, ``style``, ``header``, ``footer`` and ``nav`` elements are
    removed before extraction, and each remaining text fragment is stripped.

    Args:
        html_content: Raw HTML markup.
        separator: String inserted between text fragments. The default of a
            single space keeps adjacent elements from running together:
            ``<p>a</p><p>b</p>`` gives ``"a b"`` rather than ``"ab"``. Pass
            ``""`` to reproduce the previous glued-together output.

    Returns:
        The extracted text as a single string.
    """
    soup = BeautifulSoup(html_content, "html.parser")
    for tag in soup(["script", "style", "header", "footer", "nav"]):
        tag.decompose()  # the removed subtree is never needed again
    return soup.get_text(separator=separator, strip=True)
# Function to format the prompt for the language model
def format_prompt(user_prompt, chat_history):
    """Flatten past turns plus the new message into a plain-text transcript.

    Args:
        user_prompt: The new user message.
        chat_history: Sequence of turns indexed as ``turn[0]`` (user message)
            and ``turn[1]`` (assistant reply). ``None`` is treated as an empty
            history.

    Returns:
        One ``User:`` / ``Assistant:`` line pair per past turn, followed by the
        new ``User:`` line and an open ``Assistant:`` line for the model to
        continue.
    """
    logger.debug("Formatting prompt with user prompt: %s and chat history: %s", user_prompt, chat_history)
    parts = [f"User: {turn[0]}\nAssistant: {turn[1]}\n" for turn in chat_history or ()]
    parts.append(f"User: {user_prompt}\nAssistant:")
    # A single join avoids the quadratic cost of repeated string +=.
    prompt = "".join(parts)
    logger.debug("Formatted prompt: %s", prompt)
    return prompt
# Function for model inference
def model_inference(
    user_prompt,
    chat_history,
    web_search,
    temperature,
    max_new_tokens,
    repetition_penalty,
    top_p,
    tokenizer  # Pass tokenizer as an argument
):
    """Answer one chat turn with the globally loaded causal LM.

    Args:
        user_prompt: Dict with a ``"text"`` message and an optional ``"files"``
            list. Any non-empty ``"files"`` is rejected. The dict is not
            mutated.
        chat_history: Previous turns, forwarded to ``format_prompt``.
        web_search: When true, ``fetch_news`` results for the message are
            appended to it after a ``[NEWS]`` marker.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding, because ``transformers`` rejects non-positive
            temperatures when sampling and the UI slider allows 0.0.
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Passed through to ``model.generate``.
        top_p: Nucleus-sampling threshold. It is unused when decoding greedily.
        tokenizer: Tokenizer for the global ``model``. It may be ``None`` if
            loading failed.

    Returns:
        The assistant's reply, without the prompt text, or a human-readable
        error message string.
    """
    logger.debug(
        "Starting model inference with user prompt: %s, chat history: %s, web_search: %s",
        user_prompt, chat_history, web_search,
    )
    if not isinstance(user_prompt, dict):
        logger.error("Invalid input format. Expected a dictionary.")
        return "Invalid input format. Expected a dictionary."
    if user_prompt.get("files"):
        return "Image input not supported in this implementation."

    text = user_prompt.get("text", "")
    if web_search:
        logger.debug("Performing news search")
        news_results = fetch_news(text)
        news2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in news_results])
        formatted_prompt = format_prompt(f"{text} [NEWS] {news2}", chat_history)
    else:
        formatted_prompt = format_prompt(text, chat_history)

    response = _generate_reply(
        formatted_prompt, tokenizer, temperature, max_new_tokens, repetition_penalty, top_p
    )
    logger.debug("Model response: %s", response)
    return response


def _generate_reply(formatted_prompt, tokenizer, temperature, max_new_tokens, repetition_penalty, top_p):
    """Run the global ``model`` on *formatted_prompt* and return only the new text."""
    # Check both objects before tokenizing: a None tokenizer would otherwise
    # raise TypeError here instead of producing the availability message.
    if model is None or tokenizer is None:
        return "Model is not available. Please try again later."
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
    generation_kwargs = {
        # Slider values may arrive as floats; generate() expects an int count.
        "max_new_tokens": int(max_new_tokens) if max_new_tokens is not None else None,
        "repetition_penalty": repetition_penalty,
    }
    if temperature is not None and temperature <= 0:
        generation_kwargs["do_sample"] = False
    else:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    outputs = model.generate(**inputs, **generation_kwargs)
    # For a decoder-only model, generate() returns prompt + continuation.
    # Decoding only the tail keeps the prompt and history out of the reply.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# Define Gradio interface components (standalone controls).
# Upper bound on the number of tokens generated per reply.
max_new_tokens = gr.Slider(
    label="Maximum number of new tokens to generate",
    minimum=1,
    maximum=16000,
    step=64,
    value=2048,
    interactive=True,
)
# Repetition penalty control; per its help text, 1.0 means no penalty.
repetition_penalty = gr.Slider(
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
    minimum=0.01,
    maximum=5.0,
    step=0.01,
    value=1,
    interactive=True,
)
# Decoding-strategy selector.
# NOTE(review): the gr.Interface in this file builds its own Radio, and
# chat_interface does not forward the chosen strategy, so this component is
# currently unused.
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Top P Sampling",
    label="Decoding strategy",
    interactive=True,
    # The previous help text was copied from the Top P slider and did not
    # describe this choice.
    info="Greedy always picks the most likely next token; Top P Sampling samples from the smallest set of tokens whose probability mass reaches Top P.",
)
# Sampling temperature control (0.0 to 2.0).
temperature = gr.Slider(
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
    minimum=0.0,
    maximum=2.0,
    step=0.05,
    value=0.5,
    interactive=True,
    visible=True,
)
# Nucleus-sampling (Top P) threshold control.
top_p = gr.Slider(
    label="Top P",
    info="Higher values are equivalent to sampling more low-probability tokens.",
    minimum=0.01,
    maximum=0.99,
    step=0.01,
    value=0.9,
    interactive=True,
    visible=True,
)
# Chatbot display component.
# NOTE(review): not passed to the gr.Interface in this file, so it is not
# rendered there.
chatbot = gr.Chatbot(
    layout="panel",
    likeable=True,
    show_copy_button=True,
    label="OpenGPT-4o-Chatty",
)
# Define Gradio interface
def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
    """Gradio callback: run one chat turn and record it in *history*.

    Args:
        user_input: Text typed by the user.
        history: List of ``(user, assistant)`` tuples held in a ``gr.State``.
            ``None`` is treated as an empty history.
        web_search: Forwarded to ``model_inference`` as ``web_search``.
        decoding_strategy: Value of the "Decoding strategy" radio.
            NOTE(review): it is not forwarded to ``model_inference``, so it
            currently has no effect.
        temperature: Generation setting passed straight through to
            ``model_inference``.
        max_new_tokens: Generation setting passed straight through to
            ``model_inference``.
        repetition_penalty: Generation setting passed straight through to
            ``model_inference``.
        top_p: Generation setting passed straight through to
            ``model_inference``.

    Returns:
        ``(response, history)``: the reply, and the same history list with
        the new turn appended.
    """
    # The module-level tokenizer is only read here, so no ``global`` statement
    # is needed.
    if history is None:
        history = []
    # model_inference expects a {"text", "files"} dict; this UI is text-only.
    user_prompt = {"text": user_input, "files": []}
    response = model_inference(
        user_prompt=user_prompt,
        chat_history=history,
        web_search=web_search,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        tokenizer=tokenizer,  # module-level tokenizer (None if loading failed)
    )
    # Append in place so the list held by the State component keeps the turn.
    history.append((user_input, response))
    return response, history
# Define the Gradio interface. The order of `inputs` must match the positional
# parameters of chat_interface.
interface = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Textbox(label="User Input", placeholder="Type your message here..."),
        gr.State([]),  # chat history, initially empty
        gr.Checkbox(label="Perform Web Search"),
        # Default matches the standalone selector instead of starting as None.
        gr.Radio(["Greedy", "Top P Sampling"], label="Decoding strategy", value="Top P Sampling"),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label="Sampling temperature", value=0.5),
        gr.Slider(minimum=1, maximum=16000, step=64, label="Maximum number of new tokens to generate", value=2048),
        gr.Slider(minimum=0.01, maximum=5.0, step=0.01, label="Repetition penalty", value=1),
        gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label="Top P", value=0.9),
    ],
    outputs=[
        gr.Textbox(label="Assistant Response"),
        gr.State([]),  # updated chat history
    ],
    # `live=True` was removed deliberately. A live interface re-runs `fn` on
    # every input change. Here that meant a full LLM generation, plus a new
    # history entry, for each keystroke or slider move. Turns now run when
    # the user submits.
)
# Launch the Gradio interface only when this file is executed as a script, so
# importing the module (for tests or reuse) does not start a web server.
if __name__ == "__main__":
    interface.launch()