import inspect
import json
import os
import re

import gradio as gr
import requests
from gradio.oauth import OAuthToken
from huggingface_hub import InferenceClient
from huggingface_hub.utils import HfHubHTTPError

from ..mcp import video_tools
SAVE_FILE = "save.json"
HF_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
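# Shape of SAVE_FILE as written by save_settings() below (values illustrative):
# {
#     "last_provider": "Ollama",
#     "last_active_url": "http://localhost:11434",
#     "endpoints": [
#         {"url": "http://localhost:11434", "preferred_model": "llama3.1:8b"}
#     ]
# }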
def save_settings(provider="Ollama", url=None, preferred_model=None):
"""Saves provider-specific settings to the save file."""
settings = {}
if os.path.exists(SAVE_FILE):
try:
with open(SAVE_FILE, 'r') as f:
settings = json.load(f)
except (json.JSONDecodeError, IOError):
settings = {}
settings['last_provider'] = provider
if provider == "Ollama":
if 'endpoints' not in settings or not isinstance(settings['endpoints'], list):
settings['endpoints'] = []
if 'last_active_url' not in settings:
settings['last_active_url'] = ""
if url:
if not (url.startswith("http://") or url.startswith("https://")):
url = "http://" + url
settings['last_active_url'] = url
endpoint_data = next((e for e in settings['endpoints'] if e['url'] == url), None)
if not endpoint_data:
endpoint_data = {'url': url, 'preferred_model': ''}
settings['endpoints'].append(endpoint_data)
if preferred_model is not None:
endpoint_data['preferred_model'] = preferred_model
with open(SAVE_FILE, 'w') as f:
json.dump(settings, f, indent=4)
def load_settings():
"""Loads settings for the last active provider and their respective configs."""
if not os.path.exists(SAVE_FILE):
return "Ollama", "", "", ""
try:
with open(SAVE_FILE, "r") as f:
data = json.load(f)
last_provider = data.get("last_provider", "Ollama")
# Ollama settings
last_url = data.get("last_active_url", "")
ollama_model = ""
preferred_llm_display = ""
if last_url:
endpoint_data = next((e for e in data.get('endpoints', []) if e['url'] == last_url), None)
if endpoint_data:
ollama_model = endpoint_data.get("preferred_model", "")
preferred_llm_display = f"Preferred model set to: {ollama_model}" if ollama_model else ""
return last_provider, last_url, ollama_model, preferred_llm_display
except (json.JSONDecodeError, IOError):
return "Ollama", "", "", ""
def parse_param_size(size_str: str) -> float:
"""Converts a model size string (e.g., '7B') to a float."""
if not isinstance(size_str, str):
return 0.0
size_str = size_str.upper().strip()
if size_str.endswith('B'):
try:
return float(size_str[:-1])
except (ValueError, TypeError):
return 0.0
return 0.0
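# Examples (illustrative): parse_param_size("7B") -> 7.0,
# parse_param_size("0.5B") -> 0.5, parse_param_size("unknown") -> 0.0.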
def get_model_details(url: str, model_name: str) -> dict:
"""Inspects a model's file for details like tool use and parameter size."""
details = {
'name': model_name,
'supports_tool_calling': False,
'parameter_size': 0.0,
'parameter_size_str': "N/A"
}
try:
res = requests.post(
f"{url.rstrip('/')}/api/show",
json={"name": model_name},
timeout=10
)
res.raise_for_status()
data = res.json()
modelfile = data.get("modelfile", "")
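        # Heuristic: tool/function-related keywords in the Modelfile template
        # usually indicate the model was packaged with a tool-calling prompt format.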
keywords = ["tool", "function", "available_tools", "function_call"]
details['supports_tool_calling'] = any(keyword in modelfile.lower() for keyword in keywords)
param_size_str = data.get("details", {}).get("parameter_size", "0B")
details['parameter_size_str'] = param_size_str
details['parameter_size'] = parse_param_size(param_size_str)
return details
except (requests.exceptions.RequestException, json.JSONDecodeError):
return details
def check_ollama_endpoint(url, preferred_model=None):
"""
Checks an Ollama endpoint, gets a list of available models with their capabilities,
sorts them, and updates the UI accordingly.
"""
if not url or not url.strip():
return "Please enter a URL.", gr.update(visible=False), gr.update(visible=False), url
if not (url.startswith("http://") or url.startswith("https://")):
url = "http://" + url
api_url = f"{url.rstrip('/')}/api/tags"
try:
response = requests.get(api_url, timeout=5)
response.raise_for_status()
models_data = response.json().get("models", [])
if not models_data:
return "Connected, but no models found.", gr.update(visible=False), gr.update(visible=False), url
detailed_models = [get_model_details(url, m['name']) for m in models_data]
detailed_models.sort(key=lambda m: (not m['supports_tool_calling'], -m['parameter_size'], m['name']))
save_settings(url=url, provider="Ollama")
dropdown_choices = []
for m in detailed_models:
tool_text = "Tools: Yes" if m['supports_tool_calling'] else "Tools: No"
name_display = f"🛠️ {m['name']}" if m['supports_tool_calling'] else m['name']
dropdown_choices.append(f"{name_display} ({tool_text}, {m['parameter_size_str']})")
status_message = f"Success! Found and sorted {len(detailed_models)} models."
default_choice = dropdown_choices[0] if dropdown_choices else None
if preferred_model:
matching_choice = next((choice for choice in dropdown_choices if preferred_model in choice), None)
if matching_choice:
default_choice = matching_choice
return status_message, gr.update(choices=dropdown_choices, value=default_choice, visible=True), gr.update(visible=True), url
except requests.exceptions.RequestException:
error_message = "Connection Error: Is the address correct and Ollama running?"
return error_message, gr.update(visible=False), gr.update(visible=False), url
def set_preferred_model(model_selection, current_url):
"""Stores the selected model for the current endpoint and saves it."""
if not model_selection:
return "", "No model selected."
model_name = model_selection.split(" (")[0].strip().replace("🛠️ ", "")
save_settings(url=current_url, preferred_model=model_name, provider="Ollama")
return model_name, f"Preferred model for {current_url} set to: {model_name}"
def check_on_load(url, preferred_model):
"""
Wrapper to trigger endpoint check on load if a URL exists,
otherwise sets a neutral status. Also sets the dropdown to the preferred model.
"""
if not url or not url.strip():
return "Enter an endpoint URL and click 'Check' to begin.", gr.update(visible=False), gr.update(visible=False), None, ""
status, dropdown_update, button_update, current_url = check_ollama_endpoint(url, preferred_model)
model_display_text = f"Preferred model set to: {preferred_model}" if preferred_model else ""
return status, dropdown_update, button_update, current_url, model_display_text
def _parse_llm_tool_call(response_text: str):
"""Parses various potential JSON formats for tool calls from an LLM response."""
try:
        # Happy path: the response is already a single valid JSON object.
return json.loads(response_text)
except json.JSONDecodeError:
# Handle cases where the model returns a non-JSON string or malformed JSON
# This is a common failure mode, so we try to be flexible.
cleaned_text = response_text.strip()
# Look for the first '{' and the last '}'
start = cleaned_text.find('{')
end = cleaned_text.rfind('}')
if start != -1 and end != -1:
json_str = cleaned_text[start:end+1]
try:
# Attempt to parse the extracted substring
return json.loads(json_str)
except json.JSONDecodeError:
pass # Fall through if this also fails
# Final attempt: regex for common variations if strict parsing fails
# This is brittle but can catch common non-standard outputs.
tool_match = re.search(r'tool(?:_name)?[\'"]?\s*:\s*[\'"](\w+)[\'"]', cleaned_text, re.IGNORECASE)
args_match = re.search(r'arguments|args|tool_input', cleaned_text, re.IGNORECASE)
if tool_match and args_match:
tool_name = tool_match.group(1)
# This part is complex because argument formats vary wildly.
# For simplicity, we'll just return the tool name if we can find it.
# A more robust solution would use more advanced parsing here.
return {"tool": tool_name, "arguments": {}} # Return a simplified dict
return None
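# Illustrative inputs this parser tolerates (not exhaustive):
#   '{"tool": "getFirstFrame", "arguments": {}}'      -> parsed directly
#   'Sure! {"name": "getLastFrame", "arguments": {}}' -> braces extracted, then parsed
#   'tool: "getFirstFrame", args: unclear'            -> regex fallback, arguments dropped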
def process_video_prompt_ollama(
prompt: str,
video_path: str,
ollama_url: str,
model_name: str
):
"""Sends a prompt to an Ollama model with video tools and executes the response."""
media_output = None
status, debug_info = "Thinking...", "N/A"
raw_response = {}
if not all([prompt, video_path, ollama_url, model_name]):
status = "Error: Missing prompt, video path, or LLM configuration."
return media_output, debug_info, status, json.dumps({"error": status})
api_url = f"{ollama_url.rstrip('/')}/api/chat"
tools = [
{
"type": "function",
"function": {
"name": "getFirstFrame", "description": "Extracts the very first frame from a video file.",
"parameters": {"type": "object", "properties": {"video_path": {"type": "string", "description": "The path to the video file."}}, "required": ["video_path"]}
}
},
{
"type": "function",
"function": {
"name": "getLastFrame", "description": "Extracts the very last frame from a video file.",
"parameters": {"type": "object", "properties": {"video_path": {"type": "string", "description": "The path to the video file."}}, "required": ["video_path"]}
}
},
{
"type": "function",
"function": {
"name": "convert_mp4_to_gif", "description": "Converts a full MP4 video into a high-quality animated GIF.",
"parameters": {"type": "object", "properties": {"video_path": {"type": "string", "description": "The path to the video file."}, "maxResolution": {"type": "integer", "description": "Optional. Max dimension (width/height) for the GIF. Default is 500."}},"required": ["video_path"]}
}
}
]
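    # NOTE: these schemas are written out by hand so the parameter descriptions
    # can be explicit; _create_tool_schema() below derives equivalent (but
    # undescribed) schemas from the functions' signatures and could replace this block.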
try:
response = requests.post(
api_url,
json={
"model": model_name,
"messages": [{"role": "user", "content": prompt}],
"tools": tools,
"stream": False,
},
timeout=60,
)
response.raise_for_status()
raw_response = response.json()
print("--- RAW OLLAMA RESPONSE ---")
print(json.dumps(raw_response, indent=2))
print("---------------------------")
message = raw_response.get("message", {})
tool_calls = message.get("tool_calls")
if tool_calls:
status = f"Tool call requested: {len(tool_calls)} call(s)."
# Execute the first tool call
call = tool_calls[0]
function_info = call.get("function", {})
tool_name = function_info.get("name")
arguments = function_info.get("arguments", {})
available_tools = {
"getFirstFrame": video_tools.getFirstFrame,
"getLastFrame": video_tools.getLastFrame,
"convert_mp4_to_gif": video_tools.convert_mp4_to_gif,
}
if tool_name in available_tools:
                # Force the real video_path; the model only sees the prompt, not the temp upload path.
arguments['video_path'] = video_path
status = f"Executing tool: {tool_name}"
media_output = available_tools[tool_name](**arguments)
status = f"Successfully executed {tool_name}."
else:
status = f"Error: Model tried to call an unknown tool: {tool_name}"
else:
status = "The LLM responded without a tool call. Try rephrasing your prompt."
except requests.exceptions.RequestException as e:
status = f"Ollama Connection Error: {e}"
except Exception as e:
status = f"An unexpected error occurred: {e}"
print(f"Error processing prompt: {e}")
final_json_string = json.dumps(raw_response, indent=4)
debug_info = f"PROMPT:\n{prompt}" # Simplified debug info for now
return media_output, debug_info, status, final_json_string
def _create_tool_schema(func):
"""Creates a JSON schema for a function's parameters for use in HF prompts."""
sig = inspect.signature(func)
properties = {}
required = []
    for param in sig.parameters.values():
        param_type = "string"  # default when the annotation is missing or unrecognized
        if param.annotation is int:
            param_type = "integer"
        elif param.annotation is float:
            param_type = "number"
        elif param.annotation is bool:
            param_type = "boolean"
        properties[param.name] = {"type": param_type, "description": ""}  # descriptions could be derived from docstrings
        if param.default is inspect.Parameter.empty:
            required.append(param.name)
return {
"name": func.__name__,
"description": func.__doc__,
"parameters": {"type": "object", "properties": properties, "required": required},
}
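# For a hypothetical signature `def getFirstFrame(video_path: str): ...`,
# _create_tool_schema(getFirstFrame) returns:
# {
#     "name": "getFirstFrame",
#     "description": <the function's docstring>,
#     "parameters": {
#         "type": "object",
#         "properties": {"video_path": {"type": "string", "description": ""}},
#         "required": ["video_path"]
#     }
# }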
def _build_hf_tool_prompt(user_prompt: str, video_path: str, tools: list) -> str:
"""Builds a stricter, more direct prompt for Hugging Face tool-calling models."""
tool_schemas_str = json.dumps(tools, indent=2)
# This prompt provides a specific context (the video path) and gives direct,
# non-negotiable instructions to the model on how to format its response.
prompt_template = f"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an automated agent that controls a video processing system. Your task is to analyze a user's request and call the appropriate function to handle it. The video file you are working with is located at the following path: `{video_path}`.
You have access to the following functions.
<tools>
{tool_schemas_str}
</tools>
Analyze the user's request below. You must select one and only one function to call. Your response must be **only** the JSON for the function call, enclosed in a `<function-call>` block. Do not provide any other text, explanation, or conversation.
For example:
<function-call>
{{
"name": "getFirstFrame",
"arguments": {{
"video_path": "{video_path}"
}}
}}
</function-call><|eot_id|><|start_header_id|>user<|end_header_id|>
{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
return prompt_template.strip()
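# The <|begin_of_text|>/<|start_header_id|>/<|eot_id|> markers above are Llama 3
# chat-template tokens, matching the Llama model configured in HF_MODEL_ID.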
def process_video_prompt_hf(
prompt: str,
video_path: str,
oauth_token: OAuthToken | None = None
):
"""Sends a prompt to the Hugging Face Inference API and processes the response."""
# Initialize all possible return values at the top.
media_output = None
debug_info, status = "", ""
raw_response_for_json_component = {}
try:
        # The try block spans the whole function so any failure surfaces in the UI.
available_tools_list = [
video_tools.getFirstFrame,
video_tools.getLastFrame,
video_tools.convert_mp4_to_gif
]
available_tools_schema = [_create_tool_schema(func) for func in available_tools_list]
status = "Thinking..."
hf_prompt = _build_hf_tool_prompt(prompt, video_path, available_tools_schema)
debug_info = f"PROMPT:\n{hf_prompt}"
# Initialize client with the OAuth token
client = InferenceClient(token=oauth_token.token if oauth_token else None)
# Make the API call
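        # Caveat: chat_completion applies the model's chat template to `messages`
        # server-side, so the hand-rolled Llama 3 tokens in hf_prompt may get
        # wrapped a second time; text_generation() would send the string verbatim.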
response = client.chat_completion(
messages=[{"role": "user", "content": hf_prompt}],
model=HF_MODEL_ID,
max_tokens=500,
)
raw_response_text = response.choices[0].message.content
raw_response_for_json_component = {"api_response": raw_response_text}
tool_call = _parse_llm_tool_call(raw_response_text)
if tool_call:
            # Tool-call payloads vary between {"tool": ...} and {"name": ...}, so accept either key.
            tool_name = tool_call.get("tool") or tool_call.get("name")
            status = f"LLM wants to call tool: '{tool_name}'"
available_tools = {
"getFirstFrame": video_tools.getFirstFrame,
"getLastFrame": video_tools.getLastFrame,
"convert_mp4_to_gif": video_tools.convert_mp4_to_gif
}
if tool_name in available_tools:
tool_func = available_tools[tool_name]
# --- Argument Parsing and Execution ---
# Get the arguments from the tool call
arguments = tool_call.get("arguments", {})
# Always inject the correct video_path
arguments['video_path'] = video_path
# Call the tool with the provided arguments
media_output = tool_func(**arguments)
status = f"Successfully executed tool: {tool_name} with args: {arguments}"
raw_response_for_json_component['tool_execution_result'] = str(media_output)
else:
status = f"Error: LLM wanted to call unknown tool '{tool_name}'"
else:
status = "LLM responded, but did not request a tool call."
except HfHubHTTPError as e:
status = f"Hugging Face API Error: {e}"
# Print the raw server response to the console for debugging
print("--- RAW HUGGING FACE API ERROR RESPONSE ---")
print(e.response.text)
print("-------------------------------------------")
raw_response_for_json_component = {
"error": str(e),
"status_code": e.response.status_code,
"raw_body": e.response.text
}
# Check for common authentication/permission errors
if e.response.status_code == 401:
status += "\n(This could be an invalid token or you may need to accept the model's terms on Hugging Face.)"
elif e.response.status_code == 404:
status += f"\n(Model '{HF_MODEL_ID}' not found. Check the model ID for typos.)"
print(f"HfHubHTTPError in process_video_prompt_hf: {e}")
except Exception as e:
# This will now catch ANY error from the function and print it.
print("--- UNEXPECTED EXCEPTION IN process_video_prompt_hf ---")
import traceback
traceback.print_exc()
print("-----------------------------------------------------")
status = f"An unexpected error occurred: {e}"
raw_response_for_json_component = {"error": str(e), "traceback": traceback.format_exc()}
    # Serialize to a JSON string so every provider path hands the gr.JSON
    # component the same type.
final_json_string = json.dumps(raw_response_for_json_component, indent=4)
return media_output, debug_info, status, final_json_string
def dispatch_video_prompt(
llm_provider: str,
prompt: str,
video_path: str,
# Ollama args
ollama_url: str,
ollama_model: str,
# HF args
oauth_token: OAuthToken | None = None
):
"""
Dispatches the video prompt to the appropriate LLM provider
and returns all necessary outputs for the UI.
"""
# Initialize default return values
media_output = None
debug_info, status = "", ""
raw_response = json.dumps({}) # Default to an empty JSON object string
if llm_provider == "Ollama":
media_output, debug_info, status, raw_response = process_video_prompt_ollama(
prompt, video_path, ollama_url, ollama_model
)
elif llm_provider == "Hugging Face":
if oauth_token is None or not getattr(oauth_token, 'token', None):
status = "Error: Hugging Face token not available. Please log in."
debug_info = "OAuth token is missing or invalid."
raw_response = json.dumps({"error": status}, indent=4)
else:
media_output, debug_info, status, raw_response = process_video_prompt_hf(
prompt, video_path, oauth_token
)
else:
status = f"Error: Unknown LLM provider '{llm_provider}'."
debug_info = "Invalid provider selected."
raw_response = json.dumps({"error": status}, indent=4)
return media_output, debug_info, status, raw_response
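# Example wiring inside a gr.Blocks app (component names hypothetical):
#
#     run_btn.click(
#         dispatch_video_prompt,
#         inputs=[provider_radio, prompt_box, video_input, ollama_url_box, ollama_model_box],
#         outputs=[media_out, debug_out, status_out, raw_json_out],
#     )
#
# When the app uses gr.LoginButton, Gradio fills the trailing OAuthToken
# parameter automatically based on its type annotation.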