| import os |
| from typing import Dict, List, Optional, Tuple |
|
|
| import gradio as gr |
| from huggingface_hub import InferenceClient |
| from tavily import TavilyClient |
|
|
| from config import ( |
| HTML_SYSTEM_PROMPT, GENERIC_SYSTEM_PROMPT, HTML_SYSTEM_PROMPT_WITH_SEARCH, |
| GENERIC_SYSTEM_PROMPT_WITH_SEARCH, FollowUpSystemPrompt |
| ) |
| from chat_processing import ( |
| history_to_messages, messages_to_history, |
| remove_code_block, apply_search_replace_changes, send_to_sandbox, |
| history_to_chatbot_messages, get_gradio_language |
| ) |
| from file_processing import ( |
| extract_text_from_file, create_multimodal_message, |
| ) |
| from web_extraction import extract_website_content, enhance_query_with_search |
|
|
| |
# Provider credentials are read once at import time; each is None when the
# corresponding environment variable is not set.
HF_TOKEN = os.environ.get('HF_TOKEN')
GROQ_API_KEY = os.environ.get('GROQ_API_KEY')
FIREWORKS_API_KEY = os.environ.get('FIREWORKS_API_KEY')
|
|
def get_inference_client(model_id):
    """Build the InferenceClient used to serve ``model_id``.

    The Kimi-K2 model is sent to the Groq endpoint, any id prefixed with
    ``fireworks/`` is sent to the Fireworks AI endpoint, and every other id
    gets a client bound to that model and authenticated with ``HF_TOKEN``.
    """
    if model_id == "moonshotai/Kimi-K2-Instruct":
        client_kwargs = {
            "base_url": "https://api.groq.com/openai/v1",
            "api_key": GROQ_API_KEY,
        }
    elif model_id.startswith("fireworks/"):
        client_kwargs = {
            "base_url": "https://api.fireworks.ai/inference/v1",
            "api_key": FIREWORKS_API_KEY,
        }
    else:
        client_kwargs = {
            "model": model_id,
            "api_key": HF_TOKEN,
        }
    return InferenceClient(**client_kwargs)
|
|
| |
# Optional Tavily client: it stays None when no API key is configured or when
# constructing the client fails.
TAVILY_API_KEY = os.environ.get('TAVILY_API_KEY')
tavily_client = None
if TAVILY_API_KEY:
    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as init_error:
        # The assignment above never completed, so tavily_client is still None.
        print(f"Failed to initialize Tavily client: {init_error}")
|
|
def _resolve_generated_code(content, previous_html, is_follow_up):
    """Turn raw model output into the code that should be displayed.

    The output is first passed through ``remove_code_block``. For a follow-up
    on an existing HTML page, output that is not itself a full HTML document
    is applied to ``previous_html`` with ``apply_search_replace_changes``;
    in every other case the cleaned output is returned unchanged.
    """
    code = remove_code_block(content)
    if not is_follow_up:
        return code
    # The model may answer a follow-up with a complete document; use it directly.
    if code.strip().startswith(("<!DOCTYPE html>", "<html")):
        return code
    patched_html = apply_search_replace_changes(previous_html, code)
    return remove_code_block(patched_html)


def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optional[str], website_url: Optional[str], _setting: Dict[str, str], _history: Optional[List[Tuple[str, str]]], _current_model: Dict, enable_search: bool = False, language: str = "html"):
    """Stream generated code for a request, yielding UI updates as it arrives.

    Args:
        query: The user's prompt; ``None`` is treated as an empty string.
        image: Optional image; when given, the user message is built with
            ``create_multimodal_message``.
        file: Optional reference file; the first 5000 characters of its
            extracted text are appended to the prompt.
        website_url: Optional URL; the first 8000 characters of its extracted
            content are appended to the prompt, or an error note plus
            guidance when extraction reports an error.
        _setting: Not used by this function.
        _history: Previous (user, assistant) turns; ``None`` is treated as
            an empty list.
        _current_model: Mapping whose ``"id"`` selects the client and model.
        enable_search: Selects the search-enabled system prompts and is passed
            to ``enhance_query_with_search``.
        language: Target language; the preview is only rendered for ``"html"``.

    Yields:
        4-tuples of (code value or update, history, preview HTML, chatbot
        messages). While streaming, the first element is a ``gr.update``; the
        final yield carries the finished code and the updated history. If an
        exception is raised during generation, a single tuple is yielded with
        an ``"Error: ..."`` message and ``None`` as the preview.
    """
    if query is None:
        query = ''
    if _history is None:
        _history = []

    # If the last assistant reply contains an HTML document, treat this
    # request as a follow-up edit of that document. A None reply is treated
    # as empty instead of raising on the substring test.
    last_html = (_history[-1][1] or "") if _history else ""
    has_existing_html = '<!DOCTYPE html>' in last_html or '<html' in last_html

    if has_existing_html:
        system_prompt = FollowUpSystemPrompt
    elif language == "html":
        system_prompt = HTML_SYSTEM_PROMPT_WITH_SEARCH if enable_search else HTML_SYSTEM_PROMPT
    elif enable_search:
        system_prompt = GENERIC_SYSTEM_PROMPT_WITH_SEARCH.format(language=language)
    else:
        system_prompt = GENERIC_SYSTEM_PROMPT.format(language=language)

    messages = history_to_messages(_history, system_prompt)

    # Append (truncated) reference file text to the prompt.
    if file:
        file_text = extract_text_from_file(file)
        if file_text:
            query = f"{query}\n\n[Reference file content below]\n{file_text[:5000]}"

    # Append (truncated) website content, or guidance when extraction failed.
    if website_url and website_url.strip():
        # Normalise a None result so the prefix check below cannot raise.
        website_text = extract_website_content(website_url.strip()) or ""
        if website_text.startswith("Error"):
            fallback_guidance = """
Since I couldn't extract the website content, please provide additional details about what you'd like to build:
1. What type of website is this? (e.g., e-commerce, blog, portfolio, dashboard)
2. What are the main features you want?
3. What's the target audience?
4. Any specific design preferences? (colors, style, layout)
This will help me create a better design for you."""
            query = f"{query}\n\n[Error extracting website: {website_text}]{fallback_guidance}"
        elif website_text:
            query = f"{query}\n\n[Website content to redesign below]\n{website_text[:8000]}"

    enhanced_query = enhance_query_with_search(query, enable_search)

    client = get_inference_client(_current_model["id"])

    if image is not None:
        messages.append(create_multimodal_message(enhanced_query, image))
    else:
        messages.append({'role': 'user', 'content': enhanced_query})

    preview_unavailable = "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML.</div>"

    try:
        completion = client.chat.completions.create(
            model=_current_model["id"],
            messages=messages,
            stream=True,
            max_tokens=5000
        )
        content = ""
        for chunk in completion:
            # Skip chunks that carry no choices or no text instead of
            # failing on the index lookup.
            if not chunk.choices:
                continue
            delta_text = chunk.choices[0].delta.content
            if not delta_text:
                continue
            content += delta_text
            clean_code = _resolve_generated_code(content, last_html, has_existing_html)
            yield (
                gr.update(value=clean_code, language=get_gradio_language(language)),
                _history,
                send_to_sandbox(clean_code) if language == "html" else preview_unavailable,
                history_to_chatbot_messages(_history),
            )

        final_code = _resolve_generated_code(content, last_html, has_existing_html)
        # For follow-ups, store the patched document rather than the raw
        # search/replace output, so the final view matches what was streamed
        # and the next turn still finds an HTML document in the history.
        assistant_content = final_code if has_existing_html else content
        _history = messages_to_history(messages + [{'role': 'assistant', 'content': assistant_content}])
        yield (
            final_code,
            _history,
            send_to_sandbox(final_code) if language == "html" else preview_unavailable,
            history_to_chatbot_messages(_history),
        )

    except Exception as e:
        error_message = f"Error: {str(e)}"
        yield (error_message, _history, None, history_to_chatbot_messages(_history))