Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,7 +13,7 @@ import base64
|
|
| 13 |
import logging
|
| 14 |
import time
|
| 15 |
from urllib.parse import quote # Added for URL encoding
|
| 16 |
-
import importlib #
|
| 17 |
|
| 18 |
import gradio as gr
|
| 19 |
import spaces
|
|
@@ -84,7 +84,6 @@ def generate_image(prompt: str, width: float, height: float, guidance: float, in
|
|
| 84 |
logging.error(f"Image generation failed: {str(e)}")
|
| 85 |
return None, f"Error: {str(e)}"
|
| 86 |
|
| 87 |
-
# Base64 padding fix function
|
| 88 |
def fix_base64_padding(data):
|
| 89 |
"""Fix the padding of a Base64 string."""
|
| 90 |
if isinstance(data, bytes):
|
|
@@ -99,18 +98,12 @@ def fix_base64_padding(data):
|
|
| 99 |
|
| 100 |
return data
|
| 101 |
|
| 102 |
-
# =============================================================================
|
| 103 |
-
# Memory cleanup function
|
| 104 |
-
# =============================================================================
|
| 105 |
def clear_cuda_cache():
|
| 106 |
"""Explicitly clear the CUDA cache."""
|
| 107 |
if torch.cuda.is_available():
|
| 108 |
torch.cuda.empty_cache()
|
| 109 |
gc.collect()
|
| 110 |
|
| 111 |
-
# =============================================================================
|
| 112 |
-
# SerpHouse related functions
|
| 113 |
-
# =============================================================================
|
| 114 |
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
|
| 115 |
|
| 116 |
def extract_keywords(text: str, top_k: int = 5) -> str:
|
|
@@ -176,9 +169,6 @@ Below are the search results. Use this information to answer the query:
|
|
| 176 |
logger.error(f"Web search failed: {e}")
|
| 177 |
return f"Web search failed: {str(e)}"
|
| 178 |
|
| 179 |
-
# =============================================================================
|
| 180 |
-
# Model and processor loading
|
| 181 |
-
# =============================================================================
|
| 182 |
MAX_CONTENT_CHARS = 2000
|
| 183 |
MAX_INPUT_LENGTH = 2096
|
| 184 |
model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
|
|
@@ -191,9 +181,6 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
|
|
| 191 |
)
|
| 192 |
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
|
| 193 |
|
| 194 |
-
# =============================================================================
|
| 195 |
-
# CSV, TXT, PDF analysis functions
|
| 196 |
-
# =============================================================================
|
| 197 |
def analyze_csv_file(path: str) -> str:
|
| 198 |
try:
|
| 199 |
df = pd.read_csv(path)
|
|
@@ -238,9 +225,6 @@ def pdf_to_markdown(pdf_path: str) -> str:
|
|
| 238 |
full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
|
| 239 |
return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
|
| 240 |
|
| 241 |
-
# =============================================================================
|
| 242 |
-
# Check media file limits
|
| 243 |
-
# =============================================================================
|
| 244 |
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
|
| 245 |
image_count = 0
|
| 246 |
video_count = 0
|
|
@@ -293,9 +277,6 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
|
|
| 293 |
return False
|
| 294 |
return True
|
| 295 |
|
| 296 |
-
# =============================================================================
|
| 297 |
-
# Video processing functions
|
| 298 |
-
# =============================================================================
|
| 299 |
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
|
| 300 |
vidcap = cv2.VideoCapture(video_path)
|
| 301 |
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
|
@@ -328,9 +309,6 @@ def process_video(video_path: str) -> tuple[list[dict], list[str]]:
|
|
| 328 |
content.append({"type": "image", "url": temp_file.name})
|
| 329 |
return content, temp_files
|
| 330 |
|
| 331 |
-
# =============================================================================
|
| 332 |
-
# Interleaved <image> processing function
|
| 333 |
-
# =============================================================================
|
| 334 |
def process_interleaved_images(message: dict) -> list[dict]:
|
| 335 |
parts = re.split(r"(<image>)", message["text"])
|
| 336 |
content = []
|
|
@@ -347,9 +325,6 @@ def process_interleaved_images(message: dict) -> list[dict]:
|
|
| 347 |
content.append({"type": "text", "text": part})
|
| 348 |
return content
|
| 349 |
|
| 350 |
-
# =============================================================================
|
| 351 |
-
# File processing -> content creation
|
| 352 |
-
# =============================================================================
|
| 353 |
def is_image_file(file_path: str) -> bool:
|
| 354 |
return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
|
| 355 |
|
|
@@ -390,9 +365,6 @@ def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
|
|
| 390 |
content_list.append({"type": "image", "url": img_path})
|
| 391 |
return content_list, temp_files
|
| 392 |
|
| 393 |
-
# =============================================================================
|
| 394 |
-
# Convert history to LLM messages
|
| 395 |
-
# =============================================================================
|
| 396 |
def process_history(history: list[dict]) -> list[dict]:
|
| 397 |
messages = []
|
| 398 |
current_user_content = []
|
|
@@ -416,9 +388,6 @@ def process_history(history: list[dict]) -> list[dict]:
|
|
| 416 |
messages.append({"role": "user", "content": current_user_content})
|
| 417 |
return messages
|
| 418 |
|
| 419 |
-
# =============================================================================
|
| 420 |
-
# Model generation function (with OOM catching)
|
| 421 |
-
# =============================================================================
|
| 422 |
def _model_gen_with_oom_catch(**kwargs):
|
| 423 |
try:
|
| 424 |
model.generate(**kwargs)
|
|
@@ -433,18 +402,10 @@ def _model_gen_with_oom_catch(**kwargs):
|
|
| 433 |
def load_function_definitions(json_path="functions.json"):
|
| 434 |
"""
|
| 435 |
로컬 JSON 파일에서 함수 정의 목록을 로드하여 반환.
|
| 436 |
-
각 항목: {
|
| 437 |
-
"name": <str>,
|
| 438 |
-
"description": <str>,
|
| 439 |
-
"module_path": <str>,
|
| 440 |
-
"func_name_in_module": <str>,
|
| 441 |
-
"parameters": { ... }
|
| 442 |
-
}
|
| 443 |
"""
|
| 444 |
try:
|
| 445 |
with open(json_path, "r", encoding="utf-8") as f:
|
| 446 |
data = json.load(f)
|
| 447 |
-
# name을 키로 하는 dict 형태로 재구성
|
| 448 |
func_dict = {}
|
| 449 |
for entry in data:
|
| 450 |
func_name = entry["name"]
|
|
@@ -456,9 +417,6 @@ def load_function_definitions(json_path="functions.json"):
|
|
| 456 |
|
| 457 |
FUNCTION_DEFINITIONS = load_function_definitions("functions.json")
|
| 458 |
|
| 459 |
-
# =============================================================================
|
| 460 |
-
# Dynamic handle_function_call
|
| 461 |
-
# =============================================================================
|
| 462 |
def handle_function_call(text: str) -> str:
|
| 463 |
"""
|
| 464 |
Detects and processes function call blocks in the text using the JSON-based approach.
|
|
@@ -470,7 +428,6 @@ def handle_function_call(text: str) -> str:
|
|
| 470 |
```tool_code
|
| 471 |
get_product_name_by_PID(PID="807ZPKBL9V")
|
| 472 |
```
|
| 473 |
-
We parse that block, check if 'FUNCTION_DEFINITIONS' has an entry, then import & call it.
|
| 474 |
"""
|
| 475 |
import re
|
| 476 |
pattern = r"```tool_code\s*(.*?)\s*```"
|
|
@@ -479,12 +436,11 @@ def handle_function_call(text: str) -> str:
|
|
| 479 |
return ""
|
| 480 |
code_block = match.group(1).strip()
|
| 481 |
|
| 482 |
-
# 함수명 추출 (예: get_stock_price)
|
| 483 |
-
# 정규식: ^(\w+)\(.*\)
|
| 484 |
func_match = re.match(r'^(\w+)\((.*)\)$', code_block)
|
| 485 |
if not func_match:
|
| 486 |
logger.debug("No valid function call format found.")
|
| 487 |
return ""
|
|
|
|
| 488 |
func_name = func_match.group(1)
|
| 489 |
param_str = func_match.group(2).strip()
|
| 490 |
|
|
@@ -496,43 +452,35 @@ def handle_function_call(text: str) -> str:
|
|
| 496 |
func_info = FUNCTION_DEFINITIONS[func_name]
|
| 497 |
module_path = func_info["module_path"]
|
| 498 |
module_func_name = func_info["func_name_in_module"]
|
| 499 |
-
|
| 500 |
try:
|
| 501 |
imported_module = importlib.import_module(module_path)
|
| 502 |
except ImportError as e:
|
| 503 |
logger.error(f"Failed to import module {module_path}: {e}")
|
| 504 |
return f"```tool_output\nError: Cannot import module '{module_path}'\n```"
|
| 505 |
|
| 506 |
-
# 실제 함수 객체를 가져옴
|
| 507 |
if not hasattr(imported_module, module_func_name):
|
| 508 |
logger.error(f"Module '{module_path}' has no attribute '{module_func_name}'.")
|
| 509 |
return f"```tool_output\nError: Function '{module_func_name}' not found in module '{module_path}'\n```"
|
| 510 |
|
| 511 |
real_func = getattr(imported_module, module_func_name)
|
| 512 |
|
| 513 |
-
# 파라미터 파싱
|
| 514 |
-
# 단순 정규식으로 key="value" or key=123 식을 구분
|
| 515 |
param_pattern = r'(\w+)\s*=\s*"(.*?)"|(\w+)\s*=\s*([\d.]+)'
|
| 516 |
-
# 이 정규식은 간단히 key="string" 또는 key=123 같은 형태를 파싱
|
| 517 |
-
# 더 복잡한 경우 별도 파싱 로직이나 json.loads 기법 사용 필요
|
| 518 |
param_dict = {}
|
| 519 |
for p_match in re.finditer(param_pattern, param_str):
|
| 520 |
if p_match.group(1) and p_match.group(2):
|
| 521 |
-
# group(1)은 key, group(2)는 string value
|
| 522 |
key = p_match.group(1)
|
| 523 |
val = p_match.group(2)
|
| 524 |
param_dict[key] = val
|
| 525 |
else:
|
| 526 |
-
# group(3)은 key, group(4)는 numeric value
|
| 527 |
key = p_match.group(3)
|
| 528 |
val = p_match.group(4)
|
| 529 |
-
# 숫자 변환
|
| 530 |
if '.' in val:
|
| 531 |
param_dict[key] = float(val)
|
| 532 |
else:
|
| 533 |
param_dict[key] = int(val)
|
| 534 |
|
| 535 |
-
# 이제 실제 함수 실행
|
| 536 |
try:
|
| 537 |
result = real_func(**param_dict)
|
| 538 |
except Exception as e:
|
|
@@ -541,9 +489,6 @@ def handle_function_call(text: str) -> str:
|
|
| 541 |
|
| 542 |
return f"```tool_output\n{result}\n```"
|
| 543 |
|
| 544 |
-
# =============================================================================
|
| 545 |
-
# Main inference function
|
| 546 |
-
# =============================================================================
|
| 547 |
@spaces.GPU(duration=120)
|
| 548 |
def run(
|
| 549 |
message: dict,
|
|
@@ -555,19 +500,18 @@ def run(
|
|
| 555 |
age_group: str = "20s",
|
| 556 |
mbti_personality: str = "INTP",
|
| 557 |
sexual_openness: int = 2,
|
| 558 |
-
image_gen: bool = False
|
| 559 |
) -> Iterator[str]:
|
| 560 |
if not validate_media_constraints(message, history):
|
| 561 |
yield ""
|
| 562 |
return
|
| 563 |
temp_files = []
|
| 564 |
try:
|
| 565 |
-
# JSON에서 로드된 함수
|
| 566 |
-
# (토큰 부담이 커질 수 있으므로, 적당히 압축 요약 권장)
|
| 567 |
-
# 아래는 예시로 간단히 함수 이름만 나열
|
| 568 |
available_funcs_text = ""
|
| 569 |
for f_name, info in FUNCTION_DEFINITIONS.items():
|
| 570 |
-
|
|
|
|
| 571 |
|
| 572 |
persona = (
|
| 573 |
f"{system_prompt.strip()}\n\n"
|
|
@@ -575,7 +519,9 @@ def run(
|
|
| 575 |
f"Age Group: {age_group}\n"
|
| 576 |
f"MBTI Persona: {mbti_personality}\n"
|
| 577 |
f"Sexual Openness (1-5): {sexual_openness}\n\n"
|
| 578 |
-
"Below are the available functions you can call
|
|
|
|
|
|
|
| 579 |
f"{available_funcs_text}\n"
|
| 580 |
)
|
| 581 |
combined_system_msg = f"[System Prompt]\n{persona.strip()}\n\n"
|
|
@@ -629,7 +575,6 @@ def run(
|
|
| 629 |
output_so_far += new_text
|
| 630 |
yield output_so_far
|
| 631 |
|
| 632 |
-
# 모델 출력 중 ```tool_code``` 블록이 있으면 처리
|
| 633 |
func_result = handle_function_call(output_so_far)
|
| 634 |
if func_result:
|
| 635 |
output_so_far += "\n\n" + func_result
|
|
@@ -652,17 +597,12 @@ def run(
|
|
| 652 |
pass
|
| 653 |
clear_cuda_cache()
|
| 654 |
|
| 655 |
-
# =============================================================================
|
| 656 |
-
# Modified model run function - handles image generation and gallery update
|
| 657 |
-
# =============================================================================
|
| 658 |
def modified_run(message, history, system_prompt, max_new_tokens, use_web_search, web_search_query,
|
| 659 |
age_group, mbti_personality, sexual_openness, image_gen):
|
| 660 |
-
# Initialize and hide the gallery component
|
| 661 |
output_so_far = ""
|
| 662 |
gallery_update = gr.Gallery(visible=False, value=[])
|
| 663 |
yield output_so_far, gallery_update
|
| 664 |
|
| 665 |
-
# Execute the original run function
|
| 666 |
text_generator = run(message, history, system_prompt, max_new_tokens, use_web_search,
|
| 667 |
web_search_query, age_group, mbti_personality, sexual_openness, image_gen)
|
| 668 |
|
|
@@ -670,15 +610,12 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
|
|
| 670 |
output_so_far = text_chunk
|
| 671 |
yield output_so_far, gallery_update
|
| 672 |
|
| 673 |
-
# If image generation is enabled and there is text input, update the gallery
|
| 674 |
if image_gen and message["text"].strip():
|
| 675 |
try:
|
| 676 |
width, height = 512, 512
|
| 677 |
guidance, steps, seed = 7.5, 30, 42
|
| 678 |
|
| 679 |
logger.info(f"Calling image generation for gallery with prompt: {message['text']}")
|
| 680 |
-
|
| 681 |
-
# Call the API to generate an image
|
| 682 |
image_result, seed_info = generate_image(
|
| 683 |
prompt=message["text"].strip(),
|
| 684 |
width=width,
|
|
@@ -687,7 +624,6 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
|
|
| 687 |
inference_steps=steps,
|
| 688 |
seed=seed
|
| 689 |
)
|
| 690 |
-
|
| 691 |
if image_result:
|
| 692 |
if isinstance(image_result, str) and (
|
| 693 |
image_result.startswith('data:') or
|
|
@@ -699,22 +635,18 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
|
|
| 699 |
else:
|
| 700 |
b64data = image_result
|
| 701 |
content_type = "image/webp"
|
| 702 |
-
|
| 703 |
image_bytes = base64.b64decode(b64data)
|
| 704 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
|
| 705 |
temp_file.write(image_bytes)
|
| 706 |
temp_path = temp_file.name
|
| 707 |
gallery_update = gr.Gallery(visible=True, value=[temp_path])
|
| 708 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
| 709 |
-
|
| 710 |
except Exception as e:
|
| 711 |
logger.error(f"Error processing Base64 image: {e}")
|
| 712 |
yield output_so_far + f"\n\n(Error processing image: {e})", gallery_update
|
| 713 |
-
|
| 714 |
elif isinstance(image_result, str) and os.path.exists(image_result):
|
| 715 |
gallery_update = gr.Gallery(visible=True, value=[image_result])
|
| 716 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
| 717 |
-
|
| 718 |
elif isinstance(image_result, str) and '/tmp/' in image_result:
|
| 719 |
try:
|
| 720 |
client = Client(API_URL)
|
|
@@ -722,13 +654,11 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
|
|
| 722 |
prompt=message["text"].strip(),
|
| 723 |
api_name="/generate_base64_image"
|
| 724 |
)
|
| 725 |
-
|
| 726 |
if isinstance(result, str) and (result.startswith('data:') or len(result) > 100):
|
| 727 |
if result.startswith('data:'):
|
| 728 |
content_type, b64data = result.split(';base64,')
|
| 729 |
else:
|
| 730 |
b64data = result
|
| 731 |
-
|
| 732 |
image_bytes = base64.b64decode(b64data)
|
| 733 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
|
| 734 |
temp_file.write(image_bytes)
|
|
@@ -737,7 +667,6 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
|
|
| 737 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
| 738 |
else:
|
| 739 |
yield output_so_far + "\n\n(Image generation failed: Invalid format)", gallery_update
|
| 740 |
-
|
| 741 |
except Exception as e:
|
| 742 |
logger.error(f"Error calling alternative API: {e}")
|
| 743 |
yield output_so_far + f"\n\n(Image generation failed: {e})", gallery_update
|
|
@@ -755,14 +684,10 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
|
|
| 755 |
yield output_so_far + f"\n\n(Unsupported image format: {type(image_result)})", gallery_update
|
| 756 |
else:
|
| 757 |
yield output_so_far + f"\n\n(Image generation failed: {seed_info})", gallery_update
|
| 758 |
-
|
| 759 |
except Exception as e:
|
| 760 |
logger.error(f"Error during gallery image generation: {e}")
|
| 761 |
yield output_so_far + f"\n\n(Image generation error: {e})", gallery_update
|
| 762 |
|
| 763 |
-
# =============================================================================
|
| 764 |
-
# Examples
|
| 765 |
-
# =============================================================================
|
| 766 |
examples = [
|
| 767 |
[
|
| 768 |
{
|
|
@@ -855,7 +780,7 @@ examples = [
|
|
| 855 |
],
|
| 856 |
[
|
| 857 |
{
|
| 858 |
-
"text": "AAPL의 현재 주가를 알려줘.",
|
| 859 |
"files": []
|
| 860 |
}
|
| 861 |
],
|
|
|
|
| 13 |
import logging
|
| 14 |
import time
|
| 15 |
from urllib.parse import quote # Added for URL encoding
|
| 16 |
+
import importlib # For dynamic import
|
| 17 |
|
| 18 |
import gradio as gr
|
| 19 |
import spaces
|
|
|
|
| 84 |
logging.error(f"Image generation failed: {str(e)}")
|
| 85 |
return None, f"Error: {str(e)}"
|
| 86 |
|
|
|
|
| 87 |
def fix_base64_padding(data):
|
| 88 |
"""Fix the padding of a Base64 string."""
|
| 89 |
if isinstance(data, bytes):
|
|
|
|
| 98 |
|
| 99 |
return data
|
| 100 |
|
|
|
|
|
|
|
|
|
|
| 101 |
def clear_cuda_cache():
|
| 102 |
"""Explicitly clear the CUDA cache."""
|
| 103 |
if torch.cuda.is_available():
|
| 104 |
torch.cuda.empty_cache()
|
| 105 |
gc.collect()
|
| 106 |
|
|
|
|
|
|
|
|
|
|
| 107 |
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
|
| 108 |
|
| 109 |
def extract_keywords(text: str, top_k: int = 5) -> str:
|
|
|
|
| 169 |
logger.error(f"Web search failed: {e}")
|
| 170 |
return f"Web search failed: {str(e)}"
|
| 171 |
|
|
|
|
|
|
|
|
|
|
| 172 |
MAX_CONTENT_CHARS = 2000
|
| 173 |
MAX_INPUT_LENGTH = 2096
|
| 174 |
model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
|
|
|
|
| 181 |
)
|
| 182 |
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
|
| 183 |
|
|
|
|
|
|
|
|
|
|
| 184 |
def analyze_csv_file(path: str) -> str:
|
| 185 |
try:
|
| 186 |
df = pd.read_csv(path)
|
|
|
|
| 225 |
full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
|
| 226 |
return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
|
| 227 |
|
|
|
|
|
|
|
|
|
|
| 228 |
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
|
| 229 |
image_count = 0
|
| 230 |
video_count = 0
|
|
|
|
| 277 |
return False
|
| 278 |
return True
|
| 279 |
|
|
|
|
|
|
|
|
|
|
| 280 |
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
|
| 281 |
vidcap = cv2.VideoCapture(video_path)
|
| 282 |
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
|
|
|
| 309 |
content.append({"type": "image", "url": temp_file.name})
|
| 310 |
return content, temp_files
|
| 311 |
|
|
|
|
|
|
|
|
|
|
| 312 |
def process_interleaved_images(message: dict) -> list[dict]:
|
| 313 |
parts = re.split(r"(<image>)", message["text"])
|
| 314 |
content = []
|
|
|
|
| 325 |
content.append({"type": "text", "text": part})
|
| 326 |
return content
|
| 327 |
|
|
|
|
|
|
|
|
|
|
| 328 |
def is_image_file(file_path: str) -> bool:
|
| 329 |
return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
|
| 330 |
|
|
|
|
| 365 |
content_list.append({"type": "image", "url": img_path})
|
| 366 |
return content_list, temp_files
|
| 367 |
|
|
|
|
|
|
|
|
|
|
| 368 |
def process_history(history: list[dict]) -> list[dict]:
|
| 369 |
messages = []
|
| 370 |
current_user_content = []
|
|
|
|
| 388 |
messages.append({"role": "user", "content": current_user_content})
|
| 389 |
return messages
|
| 390 |
|
|
|
|
|
|
|
|
|
|
| 391 |
def _model_gen_with_oom_catch(**kwargs):
|
| 392 |
try:
|
| 393 |
model.generate(**kwargs)
|
|
|
|
| 402 |
def load_function_definitions(json_path="functions.json"):
|
| 403 |
"""
|
| 404 |
로컬 JSON 파일에서 함수 정의 목록을 로드하여 반환.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
"""
|
| 406 |
try:
|
| 407 |
with open(json_path, "r", encoding="utf-8") as f:
|
| 408 |
data = json.load(f)
|
|
|
|
| 409 |
func_dict = {}
|
| 410 |
for entry in data:
|
| 411 |
func_name = entry["name"]
|
|
|
|
| 417 |
|
| 418 |
FUNCTION_DEFINITIONS = load_function_definitions("functions.json")
|
| 419 |
|
|
|
|
|
|
|
|
|
|
| 420 |
def handle_function_call(text: str) -> str:
|
| 421 |
"""
|
| 422 |
Detects and processes function call blocks in the text using the JSON-based approach.
|
|
|
|
| 428 |
```tool_code
|
| 429 |
get_product_name_by_PID(PID="807ZPKBL9V")
|
| 430 |
```
|
|
|
|
| 431 |
"""
|
| 432 |
import re
|
| 433 |
pattern = r"```tool_code\s*(.*?)\s*```"
|
|
|
|
| 436 |
return ""
|
| 437 |
code_block = match.group(1).strip()
|
| 438 |
|
|
|
|
|
|
|
| 439 |
func_match = re.match(r'^(\w+)\((.*)\)$', code_block)
|
| 440 |
if not func_match:
|
| 441 |
logger.debug("No valid function call format found.")
|
| 442 |
return ""
|
| 443 |
+
|
| 444 |
func_name = func_match.group(1)
|
| 445 |
param_str = func_match.group(2).strip()
|
| 446 |
|
|
|
|
| 452 |
func_info = FUNCTION_DEFINITIONS[func_name]
|
| 453 |
module_path = func_info["module_path"]
|
| 454 |
module_func_name = func_info["func_name_in_module"]
|
| 455 |
+
|
| 456 |
try:
|
| 457 |
imported_module = importlib.import_module(module_path)
|
| 458 |
except ImportError as e:
|
| 459 |
logger.error(f"Failed to import module {module_path}: {e}")
|
| 460 |
return f"```tool_output\nError: Cannot import module '{module_path}'\n```"
|
| 461 |
|
|
|
|
| 462 |
if not hasattr(imported_module, module_func_name):
|
| 463 |
logger.error(f"Module '{module_path}' has no attribute '{module_func_name}'.")
|
| 464 |
return f"```tool_output\nError: Function '{module_func_name}' not found in module '{module_path}'\n```"
|
| 465 |
|
| 466 |
real_func = getattr(imported_module, module_func_name)
|
| 467 |
|
| 468 |
+
# 간단 파라미터 파싱 (key="value" or key=123)
|
|
|
|
| 469 |
param_pattern = r'(\w+)\s*=\s*"(.*?)"|(\w+)\s*=\s*([\d.]+)'
|
|
|
|
|
|
|
| 470 |
param_dict = {}
|
| 471 |
for p_match in re.finditer(param_pattern, param_str):
|
| 472 |
if p_match.group(1) and p_match.group(2):
|
|
|
|
| 473 |
key = p_match.group(1)
|
| 474 |
val = p_match.group(2)
|
| 475 |
param_dict[key] = val
|
| 476 |
else:
|
|
|
|
| 477 |
key = p_match.group(3)
|
| 478 |
val = p_match.group(4)
|
|
|
|
| 479 |
if '.' in val:
|
| 480 |
param_dict[key] = float(val)
|
| 481 |
else:
|
| 482 |
param_dict[key] = int(val)
|
| 483 |
|
|
|
|
| 484 |
try:
|
| 485 |
result = real_func(**param_dict)
|
| 486 |
except Exception as e:
|
|
|
|
| 489 |
|
| 490 |
return f"```tool_output\n{result}\n```"
|
| 491 |
|
|
|
|
|
|
|
|
|
|
| 492 |
@spaces.GPU(duration=120)
|
| 493 |
def run(
|
| 494 |
message: dict,
|
|
|
|
| 500 |
age_group: str = "20s",
|
| 501 |
mbti_personality: str = "INTP",
|
| 502 |
sexual_openness: int = 2,
|
| 503 |
+
image_gen: bool = False
|
| 504 |
) -> Iterator[str]:
|
| 505 |
if not validate_media_constraints(message, history):
|
| 506 |
yield ""
|
| 507 |
return
|
| 508 |
temp_files = []
|
| 509 |
try:
|
| 510 |
+
# JSON에서 로드된 함수 정보 문자열화 (예: 함수명과 example_usage만)
|
|
|
|
|
|
|
| 511 |
available_funcs_text = ""
|
| 512 |
for f_name, info in FUNCTION_DEFINITIONS.items():
|
| 513 |
+
example_usage = info.get("example_usage", "")
|
| 514 |
+
available_funcs_text += f"\n\nFunction: {f_name}\nDescription: {info['description']}\nExample:\n{example_usage}\n"
|
| 515 |
|
| 516 |
persona = (
|
| 517 |
f"{system_prompt.strip()}\n\n"
|
|
|
|
| 519 |
f"Age Group: {age_group}\n"
|
| 520 |
f"MBTI Persona: {mbti_personality}\n"
|
| 521 |
f"Sexual Openness (1-5): {sexual_openness}\n\n"
|
| 522 |
+
"Below are the available functions you can call.\n"
|
| 523 |
+
"Important: Use the format exactly like: ```tool_code\nfunctionName(param=\"string\", ...)\n```\n"
|
| 524 |
+
"(Strings must be in double quotes)\n"
|
| 525 |
f"{available_funcs_text}\n"
|
| 526 |
)
|
| 527 |
combined_system_msg = f"[System Prompt]\n{persona.strip()}\n\n"
|
|
|
|
| 575 |
output_so_far += new_text
|
| 576 |
yield output_so_far
|
| 577 |
|
|
|
|
| 578 |
func_result = handle_function_call(output_so_far)
|
| 579 |
if func_result:
|
| 580 |
output_so_far += "\n\n" + func_result
|
|
|
|
| 597 |
pass
|
| 598 |
clear_cuda_cache()
|
| 599 |
|
|
|
|
|
|
|
|
|
|
| 600 |
def modified_run(message, history, system_prompt, max_new_tokens, use_web_search, web_search_query,
|
| 601 |
age_group, mbti_personality, sexual_openness, image_gen):
|
|
|
|
| 602 |
output_so_far = ""
|
| 603 |
gallery_update = gr.Gallery(visible=False, value=[])
|
| 604 |
yield output_so_far, gallery_update
|
| 605 |
|
|
|
|
| 606 |
text_generator = run(message, history, system_prompt, max_new_tokens, use_web_search,
|
| 607 |
web_search_query, age_group, mbti_personality, sexual_openness, image_gen)
|
| 608 |
|
|
|
|
| 610 |
output_so_far = text_chunk
|
| 611 |
yield output_so_far, gallery_update
|
| 612 |
|
|
|
|
| 613 |
if image_gen and message["text"].strip():
|
| 614 |
try:
|
| 615 |
width, height = 512, 512
|
| 616 |
guidance, steps, seed = 7.5, 30, 42
|
| 617 |
|
| 618 |
logger.info(f"Calling image generation for gallery with prompt: {message['text']}")
|
|
|
|
|
|
|
| 619 |
image_result, seed_info = generate_image(
|
| 620 |
prompt=message["text"].strip(),
|
| 621 |
width=width,
|
|
|
|
| 624 |
inference_steps=steps,
|
| 625 |
seed=seed
|
| 626 |
)
|
|
|
|
| 627 |
if image_result:
|
| 628 |
if isinstance(image_result, str) and (
|
| 629 |
image_result.startswith('data:') or
|
|
|
|
| 635 |
else:
|
| 636 |
b64data = image_result
|
| 637 |
content_type = "image/webp"
|
|
|
|
| 638 |
image_bytes = base64.b64decode(b64data)
|
| 639 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
|
| 640 |
temp_file.write(image_bytes)
|
| 641 |
temp_path = temp_file.name
|
| 642 |
gallery_update = gr.Gallery(visible=True, value=[temp_path])
|
| 643 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
|
|
|
| 644 |
except Exception as e:
|
| 645 |
logger.error(f"Error processing Base64 image: {e}")
|
| 646 |
yield output_so_far + f"\n\n(Error processing image: {e})", gallery_update
|
|
|
|
| 647 |
elif isinstance(image_result, str) and os.path.exists(image_result):
|
| 648 |
gallery_update = gr.Gallery(visible=True, value=[image_result])
|
| 649 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
|
|
|
| 650 |
elif isinstance(image_result, str) and '/tmp/' in image_result:
|
| 651 |
try:
|
| 652 |
client = Client(API_URL)
|
|
|
|
| 654 |
prompt=message["text"].strip(),
|
| 655 |
api_name="/generate_base64_image"
|
| 656 |
)
|
|
|
|
| 657 |
if isinstance(result, str) and (result.startswith('data:') or len(result) > 100):
|
| 658 |
if result.startswith('data:'):
|
| 659 |
content_type, b64data = result.split(';base64,')
|
| 660 |
else:
|
| 661 |
b64data = result
|
|
|
|
| 662 |
image_bytes = base64.b64decode(b64data)
|
| 663 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
|
| 664 |
temp_file.write(image_bytes)
|
|
|
|
| 667 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
| 668 |
else:
|
| 669 |
yield output_so_far + "\n\n(Image generation failed: Invalid format)", gallery_update
|
|
|
|
| 670 |
except Exception as e:
|
| 671 |
logger.error(f"Error calling alternative API: {e}")
|
| 672 |
yield output_so_far + f"\n\n(Image generation failed: {e})", gallery_update
|
|
|
|
| 684 |
yield output_so_far + f"\n\n(Unsupported image format: {type(image_result)})", gallery_update
|
| 685 |
else:
|
| 686 |
yield output_so_far + f"\n\n(Image generation failed: {seed_info})", gallery_update
|
|
|
|
| 687 |
except Exception as e:
|
| 688 |
logger.error(f"Error during gallery image generation: {e}")
|
| 689 |
yield output_so_far + f"\n\n(Image generation error: {e})", gallery_update
|
| 690 |
|
|
|
|
|
|
|
|
|
|
| 691 |
examples = [
|
| 692 |
[
|
| 693 |
{
|
|
|
|
| 780 |
],
|
| 781 |
[
|
| 782 |
{
|
| 783 |
+
"text": "AAPL의 현재 주가를 알려줘.",
|
| 784 |
"files": []
|
| 785 |
}
|
| 786 |
],
|