Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import re | |
| import base64 | |
| import glob | |
| import logging | |
| import random | |
| import shutil | |
| import time | |
| import zipfile | |
| import json | |
| import asyncio | |
| import aiofiles | |
| import toml | |
| from datetime import datetime | |
| from collections import Counter | |
| from dataclasses import dataclass, field | |
| from io import BytesIO | |
| from typing import Optional, List, Dict, Any | |
| import pandas as pd | |
| import pytz | |
| import streamlit as st | |
| from PIL import Image, ImageDraw | |
| from reportlab.pdfgen import canvas | |
| from reportlab.lib.utils import ImageReader | |
| from reportlab.lib.pagesizes import letter | |
| from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, PageBreak | |
| from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle | |
| from reportlab.lib.enums import TA_JUSTIFY | |
| import fitz | |
| import requests | |
| try: | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq, pipeline | |
| _transformers_available = True | |
| except ImportError: | |
| _transformers_available = False | |
| st.sidebar.warning("AI/ML libraries (torch, transformers) not found. Local model features disabled.") | |
| try: | |
| from diffusers import StableDiffusionPipeline | |
| _diffusers_available = True | |
| except ImportError: | |
| _diffusers_available = False | |
| if _transformers_available: | |
| st.sidebar.warning("Diffusers library not found. Diffusion model features disabled.") | |
| try: | |
| from openai import OpenAI | |
| _openai_available = True | |
| except ImportError: | |
| _openai_available = False | |
| st.sidebar.warning("OpenAI library not found. OpenAI model features disabled.") | |
| from huggingface_hub import InferenceClient, HfApi, list_models | |
| from huggingface_hub.utils import RepositoryNotFoundError, GatedRepoError | |
| # --- App Configuration --- | |
# Configure the Streamlit page (title, icon, layout and menu entries).
# NOTE(review): the emoji in these strings were mis-encoded in the source
# (UTF-8 bytes rendered through a Greek code page). The robot and the
# framed-picture glyphs decode unambiguously; the rocket and the galaxy
# glyphs are a best-effort reconstruction -- confirm against the original.
st.set_page_config(
    page_title="Vision & Layout Titans 🚀🖼️",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'https://huggingface.co/docs',
        'Report a Bug': None,
        'About': "Combined App: Image/MD->PDF Layout + AI-Powered Tools 🌌",
    },
)
| # --- Secrets Management --- | |
# Resolve API credentials. Values found in .streamlit/secrets.toml win over
# environment variables; if reading the file fails, only the environment is
# consulted.
try:
    if os.path.exists(".streamlit/secrets.toml"):
        secrets = toml.load(".streamlit/secrets.toml")
    else:
        secrets = {}
    HF_TOKEN = secrets.get("HF_TOKEN", os.getenv("HF_TOKEN", ""))
    OPENAI_API_KEY = secrets.get("OPENAI_API_KEY", os.getenv("OPENAI_API_KEY", ""))
except Exception as e:
    st.error(f"Error loading secrets: {e}")
    HF_TOKEN = os.getenv("HF_TOKEN", "")
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")

# Tell the user which integrations are running without credentials.
if not HF_TOKEN:
    st.sidebar.warning("Hugging Face token not found in secrets or environment. Some features may be limited.")
if _openai_available and not OPENAI_API_KEY:
    st.sidebar.warning("OpenAI API key not found in secrets or environment. OpenAI features disabled.")
| # --- Logging Setup --- | |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# In-memory copy of every record emitted through `logger` during this run.
log_records = []


class LogCaptureHandler(logging.Handler):
    """Logging handler that appends each record to the module-level ``log_records`` list."""

    def emit(self, record):
        log_records.append(record)


# logging.getLogger() returns the same logger object for the lifetime of the
# process, so a script that is executed repeatedly would otherwise attach one
# more capture handler per execution. Drop capture handlers left by earlier
# executions first; they are matched by class name because the class object
# itself is re-created every time this code runs.
for _stale_handler in list(logger.handlers):
    if type(_stale_handler).__name__ == "LogCaptureHandler":
        logger.removeHandler(_stale_handler)
logger.addHandler(LogCaptureHandler())
| # --- Model Initialization --- | |
# Provider used for Hugging Face inference unless the user picks another one.
DEFAULT_PROVIDER = "hf-inference"

# Each constant below is a plain list of model identifiers, one per line.

# Chat/instruct models offered for text processing through the HF API.
FEATURED_MODELS_LIST = """
    meta-llama/Meta-Llama-3.1-8B-Instruct
    mistralai/Mistral-7B-Instruct-v0.3
    google/gemma-2-9b-it
    Qwen/Qwen2-7B-Instruct
    microsoft/Phi-3-mini-4k-instruct
    HuggingFaceH4/zephyr-7b-beta
    NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO
    HuggingFaceTB/SmolLM-1.7B-Instruct
""".split()

# Image captioning / OCR / vision models.
VISION_MODELS_LIST = """
    Salesforce/blip-image-captioning-large
    microsoft/trocr-large-handwritten
    llava-hf/llava-1.5-7b-hf
    google/vit-base-patch16-224
""".split()

# Text-to-image diffusion checkpoints.
DIFFUSION_MODELS_LIST = """
    stabilityai/stable-diffusion-xl-base-1.0
    runwayml/stable-diffusion-v1-5
    OFA-Sys/small-stable-diffusion-v0
""".split()

# OpenAI model names offered in the UI.
# NOTE(review): text-davinci-003 is a legacy completions model -- verify that
# it still works with the chat-completions calls made elsewhere in this file.
OPENAI_MODELS_LIST = """
    gpt-4o
    gpt-4-turbo
    gpt-3.5-turbo
    text-davinci-003
""".split()
# Client/model handles. setdefault() only fills keys that are still missing,
# so values already present in the session are left alone.
for _state_key, _state_default in (
    ('local_models', {}),
    ('hf_inference_client', None),
    ('openai_client', None),
):
    st.session_state.setdefault(_state_key, _state_default)

# Build the OpenAI client when both the library and an API key are available.
if _openai_available and OPENAI_API_KEY:
    try:
        st.session_state['openai_client'] = OpenAI(api_key=OPENAI_API_KEY)
    except Exception as e:
        st.error(f"Failed to initialize OpenAI client: {e}")
        logger.error(f"OpenAI client initialization failed: {e}")
        st.session_state['openai_client'] = None
    else:
        logger.info("OpenAI client initialized successfully.")

# --- Session State Initialization ---
# Defaults for UI state, asset bookkeeping and generation parameters.
for _state_key, _state_default in (
    ('layout_snapshots', []),
    ('layout_new_uploads', []),
    ('history', []),
    ('processing', {}),
    ('asset_checkboxes', {'image': {}, 'md': {}, 'pdf': {}}),
    ('downloaded_pdfs', {}),
    ('unique_counter', 0),
    ('cam0_file', None),
    ('cam1_file', None),
    ('characters', []),
    ('char_form_reset_key', 0),
    ('gallery_size', 10),
    ('hf_provider', DEFAULT_PROVIDER),
    ('hf_custom_key', ""),
    ('hf_selected_api_model', FEATURED_MODELS_LIST[0]),
    ('hf_custom_api_model', ""),
    ('openai_selected_model', OPENAI_MODELS_LIST[0] if _openai_available else ""),
    ('selected_local_model_path', None),
    ('gen_max_tokens', 512),
    ('gen_temperature', 0.7),
    ('gen_top_p', 0.95),
    ('gen_frequency_penalty', 0.0),
):
    st.session_state.setdefault(_state_key, _state_default)

# One sidebar placeholder per asset type, created only when none is stored yet.
if 'asset_gallery_container' not in st.session_state:
    st.session_state['asset_gallery_container'] = {
        asset_kind: st.sidebar.empty() for asset_kind in ('image', 'md', 'pdf')
    }
| # --- Dataclasses --- | |
@dataclass
class LocalModelConfig:
    """Description of a model stored on local disk.

    ``local_path`` is not a constructor argument; it is derived in
    ``__post_init__`` as ``<model_type>_models/<sanitised name>``.

    The ``@dataclass`` decorator is required: without it no ``__init__`` is
    generated, ``__post_init__`` is never called and ``local_path`` would stay
    a ``dataclasses.Field`` object.
    """

    name: str
    hf_id: str
    model_type: str
    size_category: str = "unknown"
    domain: Optional[str] = None
    local_path: str = field(init=False)

    def __post_init__(self):
        type_folder = f"{self.model_type}_models"
        # Collapse every run of characters other than word chars and '-' into '_'.
        safe_name = re.sub(r'[^\w\-]+', '_', self.name)
        self.local_path = os.path.join(type_folder, safe_name)

    def get_full_path(self):
        """Return ``local_path`` as an absolute path."""
        return os.path.abspath(self.local_path)
@dataclass
class DiffusionConfig:
    """Description of a diffusion model.

    The ``@dataclass`` decorator is required so that the annotated fields
    below become constructor arguments.
    """

    name: str
    base_model: str
    size: str
    domain: Optional[str] = None

    def model_path(self):
        """Return the relative folder ``diffusion_models/<name>``."""
        return f"diffusion_models/{self.name}"
| # --- Helper Functions --- | |
def generate_filename(sequence, ext="png"):
    """Build ``<sanitised sequence>_<YYYYmmdd_HHMMSS>.<ext>`` from the current local time.

    Every run of characters in ``str(sequence)`` that is not a word character
    or ``-`` is replaced by a single underscore.
    """
    stem = re.sub(r'[^\w\-]+', '_', str(sequence))
    stamp = time.strftime('%Y%m%d_%H%M%S')
    return "{}_{}.{}".format(stem, stamp, ext)
def pdf_url_to_filename(url):
    """Turn a URL into a ``.pdf`` file name.

    The leading ``http://`` or ``https://`` is dropped, each of the characters
    ``< > : " / \\ | ? *`` becomes ``_``, the result is cut to 100 characters
    and ``.pdf`` is appended.
    """
    without_scheme = re.sub(r'^https?://', '', url)
    sanitized = re.sub(r'[<>:"/\\|?*]', '_', without_scheme)
    return f"{sanitized[:100]}.pdf"
def get_download_link(file_path, mime_type="application/octet-stream", label="Download"):
    """Return an HTML ``<a>`` tag that embeds the file as a base64 data URI.

    Returns ``"<label> (File not found)"`` when the path does not exist and
    ``"<label> (Error)"`` (after logging) when reading or encoding fails.
    """
    if not os.path.exists(file_path):
        return f"{label} (File not found)"
    try:
        with open(file_path, "rb") as handle:
            encoded = base64.b64encode(handle.read()).decode()
        download_name = os.path.basename(file_path)
        return f'<a href="data:{mime_type};base64,{encoded}" download="{download_name}">{label}</a>'
    except Exception as e:
        logger.error(f"Error creating download link for {file_path}: {e}")
        return f"{label} (Error)"
def zip_directory(directory_path, zip_path):
    """Write every file under ``directory_path`` into a deflate-compressed zip at ``zip_path``.

    Archive member names are relative to the parent of ``directory_path``.
    """
    archive_root = os.path.dirname(directory_path)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for current_dir, _subdirs, filenames in os.walk(directory_path):
            for filename in filenames:
                absolute = os.path.join(current_dir, filename)
                archive.write(absolute, os.path.relpath(absolute, archive_root))
def get_local_model_paths(model_type="causal"):
    """List the directories directly inside ``<model_type>_models/`` (relative to the CWD)."""
    candidates = glob.glob(f"{model_type}_models/*")
    return list(filter(os.path.isdir, candidates))
def get_gallery_files(file_types=("png", "pdf", "jpg", "jpeg", "md", "txt")):
    """Return a sorted list of files in the CWD whose extension is in ``file_types``.

    Each extension is matched in all-lowercase and all-uppercase form; any file
    whose name is ``readme.md`` (compared case-insensitively) is left out.
    """
    found = set()
    for extension in file_types:
        for pattern in (f"*.{extension.lower()}", f"*.{extension.upper()}"):
            found.update(glob.glob(pattern))
    kept = {path for path in found if os.path.basename(path).lower() != 'readme.md'}
    return sorted(kept)
def get_typed_gallery_files(file_type):
    """Return the gallery files for one asset category.

    ``'image'`` maps to png/jpg/jpeg, ``'md'`` to md and ``'pdf'`` to pdf; any
    other value yields an empty list.
    """
    if file_type == 'image':
        extensions = ('png', 'jpg', 'jpeg')
    elif file_type == 'md':
        extensions = ('md',)
    elif file_type == 'pdf':
        extensions = ('pdf',)
    else:
        return []
    return get_gallery_files(extensions)
def download_pdf(url, output_path):
    """Stream ``url`` to ``output_path``.

    Args:
        url: Address to fetch with an HTTP GET (20 s timeout).
        output_path: File that receives the response body.

    Returns:
        True when the download completed, False on any failure. On failure
        the error is logged and a partially written ``output_path`` is removed.
    """

    def _discard_partial_file():
        # Best-effort cleanup: a missing file is fine, anything else is logged
        # instead of being swallowed silently.
        try:
            os.remove(output_path)
        except FileNotFoundError:
            pass
        except OSError as cleanup_error:
            logger.warning(f"Could not remove partial download {output_path}: {cleanup_error}")

    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(url, stream=True, timeout=20, headers=headers)
        response.raise_for_status()
        with open(output_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip empty keep-alive chunks
                    f.write(chunk)
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to download {url}: {e}")
    except Exception as e:
        logger.error(f"An unexpected error occurred during download of {url}: {e}")
    else:
        logger.info(f"Successfully downloaded {url} to {output_path}")
        return True

    _discard_partial_file()
    return False
async def process_pdf_snapshot(pdf_path, mode="single", resolution_factor=2.0):
    """Render pages of a PDF to PNG files in the current directory.

    Args:
        pdf_path: Path of the PDF to open with PyMuPDF.
        mode: ``"single"`` renders at most one page, ``"twopage"`` at most two,
            any other value renders every page.
        resolution_factor: Zoom applied on both axes when rasterising.

    Returns:
        The list of PNG file names written, or ``[]`` on failure (in which
        case files already written by this call are removed again).
    """
    start_time = time.time()
    status_placeholder = st.empty()
    status_placeholder.text(f"Processing PDF Snapshot ({mode}, Res: {resolution_factor}x)... (0s)")
    output_files = []
    doc = None
    try:
        doc = fitz.open(pdf_path)
        matrix = fitz.Matrix(resolution_factor, resolution_factor)
        total_pages = len(doc)
        page_caps = {"single": 1, "twopage": 2}
        num_pages_to_process = min(page_caps.get(mode, total_pages), total_pages)
        base_name = os.path.splitext(os.path.basename(pdf_path))[0]
        for i in range(num_pages_to_process):
            page_start_time = time.time()
            page = doc[i]
            pix = page.get_pixmap(matrix=matrix)
            output_file = generate_filename(f"{base_name}_pg{i+1}_{mode}", "png")
            # Writing the PNG is pushed to a worker thread.
            await asyncio.to_thread(pix.save, output_file)
            output_files.append(output_file)
            elapsed_page = int(time.time() - page_start_time)
            status_placeholder.text(f"Processing PDF Snapshot ({mode}, Res: {resolution_factor}x)... Page {i+1}/{num_pages_to_process} done ({elapsed_page}s)")
            await asyncio.sleep(0.01)
        elapsed = int(time.time() - start_time)
        status_placeholder.success(f"PDF Snapshot ({mode}, {len(output_files)} files) completed in {elapsed}s!")
        return output_files
    except Exception as e:
        logger.error(f"Failed to process PDF snapshot for {pdf_path}: {e}")
        status_placeholder.error(f"Failed to process PDF {os.path.basename(pdf_path)}: {e}")
        for f in output_files:
            # A failing removal must not mask the original error.
            try:
                os.remove(f)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.warning(f"Could not remove partial snapshot {f}: {cleanup_error}")
        return []
    finally:
        # Close the document on every path, including failures.
        if doc is not None:
            doc.close()
def get_hf_client() -> Optional[InferenceClient]:
    """Return the session's Hugging Face ``InferenceClient``, rebuilding it when needed.

    The token is the user's custom key when set, otherwise ``HF_TOKEN``. A
    missing token is an error for every provider except ``"hf-inference"``
    (returns ``None``). The cached client is reused only when its provider
    matches and its ``_token`` attribute matches the token currently selected.
    If construction fails, the error is reported and ``None`` is stored and
    returned.
    """
    provider = st.session_state.hf_provider
    custom_key = st.session_state.hf_custom_key.strip()
    token_to_use = custom_key or HF_TOKEN

    if not token_to_use:
        if provider != "hf-inference":
            st.error(f"Provider '{provider}' requires a Hugging Face API token.")
            return None
        logger.warning("Using hf-inference provider without a token. Rate limits may apply.")
        token_to_use = None

    cached = st.session_state.get('hf_inference_client')
    reusable = False
    if cached and cached.provider == provider:
        has_token_attr = hasattr(cached, '_token')
        if custom_key:
            reusable = has_token_attr and cached._token == custom_key
        elif HF_TOKEN:
            reusable = has_token_attr and cached._token == HF_TOKEN
        else:
            reusable = (not has_token_attr) or cached._token is None

    if not reusable:
        try:
            logger.info(f"Initializing InferenceClient for provider: {provider}.")
            st.session_state.hf_inference_client = InferenceClient(token=token_to_use, provider=provider)
            logger.info("InferenceClient initialized successfully.")
        except Exception as e:
            st.error(f"Failed to initialize Hugging Face client: {e}")
            logger.error(f"InferenceClient initialization failed: {e}")
            st.session_state.hf_inference_client = None

    return st.session_state.hf_inference_client
def process_text_hf(text: str, prompt: str, use_api: bool, model_id: str = None) -> str:
    """Run ``prompt`` against ``text`` with a Hugging Face model.

    Args:
        text: Content to process; appended to ``prompt`` after a ``---`` separator.
        prompt: Instruction describing what to do with ``text``.
        use_api: True to call the Hugging Face inference API, False to use the
            local causal LM selected in the session state.
        model_id: API model override; falls back to the custom model field and
            then to the selected featured model. Ignored for local inference.

    Returns:
        The generated text, or a string starting with ``"Error"`` describing
        what went wrong.

    Generation settings come from ``st.session_state`` (``gen_*`` keys). A
    seed other than ``-1`` is forwarded to the API call; the local path does
    not use it.
    """
    status_placeholder = st.empty()
    start_time = time.time()
    result_text = ""
    params = {
        "max_new_tokens": st.session_state.gen_max_tokens,
        "temperature": st.session_state.gen_temperature,
        "top_p": st.session_state.gen_top_p,
        "repetition_penalty": st.session_state.gen_frequency_penalty + 1.0,
    }
    # 'gen_seed' may not exist in the session yet; treat that like -1 (no seed)
    # instead of failing with an AttributeError.
    seed = st.session_state.get('gen_seed', -1)
    if seed is not None and seed != -1:
        params["seed"] = seed
    system_prompt = "You are a helpful assistant. Process the following text based on the user's request."
    full_prompt = f"{prompt}\n\n---\n\n{text}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": full_prompt}
    ]
    if use_api:
        status_placeholder.info("Processing text using Hugging Face API...")
        client = get_hf_client()
        if not client:
            return "Error: Hugging Face client not available."
        model_id = model_id or st.session_state.hf_custom_api_model.strip() or st.session_state.hf_selected_api_model
        status_placeholder.info(f"Using API Model: {model_id}")
        api_kwargs = {
            "model": model_id,
            "messages": messages,
            "max_tokens": params['max_new_tokens'],
            "temperature": params['temperature'],
            "top_p": params['top_p'],
        }
        if "seed" in params:
            # Previously the seed was collected but never sent.
            api_kwargs["seed"] = params["seed"]
        try:
            response = client.chat_completion(**api_kwargs)
            result_text = response.choices[0].message.content or ""
            logger.info(f"HF API text processing successful for model {model_id}.")
        except Exception as e:
            logger.error(f"HF API text processing failed for model {model_id}: {e}")
            result_text = f"Error during Hugging Face API inference: {str(e)}"
    else:
        status_placeholder.info("Processing text using local model...")
        if not _transformers_available:
            return "Error: Transformers library not available."
        model_path = st.session_state.get('selected_local_model_path')
        if not model_path or model_path not in st.session_state.get('local_models', {}):
            return "Error: No suitable local model selected."
        local_model_data = st.session_state['local_models'][model_path]
        if local_model_data.get('type') != 'causal':
            return f"Error: Loaded model '{os.path.basename(model_path)}' is not a Causal LM."
        status_placeholder.info(f"Using Local Model: {os.path.basename(model_path)}")
        model = local_model_data.get('model')
        tokenizer = local_model_data.get('tokenizer')
        if not model or not tokenizer:
            return f"Error: Model or tokenizer not found for {os.path.basename(model_path)}."
        try:
            try:
                prompt_for_model = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            except Exception:
                # Tokenizers without a chat template get a plain role-prefixed prompt.
                logger.warning(f"Could not apply chat template for {model_path}. Using basic formatting.")
                prompt_for_model = f"System: {system_prompt}\nUser: {full_prompt}\nAssistant:"
            inputs = tokenizer(prompt_for_model, return_tensors="pt", padding=True, truncation=True, max_length=params['max_new_tokens'] * 2)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            generate_params = {
                "max_new_tokens": params['max_new_tokens'],
                "temperature": params['temperature'],
                "top_p": params['top_p'],
                "repetition_penalty": params.get('repetition_penalty', 1.0),
                # Sampling is switched off for temperatures of 0.1 or below.
                "do_sample": True if params['temperature'] > 0.1 else False,
                "pad_token_id": tokenizer.eos_token_id
            }
            with torch.no_grad():
                outputs = model.generate(**inputs, **generate_params)
            # Decode only the tokens produced after the prompt.
            input_length = inputs['input_ids'].shape[1]
            generated_ids = outputs[0][input_length:]
            result_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
            logger.info(f"Local text processing successful for model {model_path}.")
        except Exception as e:
            logger.error(f"Local text processing failed for model {model_path}: {e}")
            result_text = f"Error during local model inference: {str(e)}"
    elapsed = int(time.time() - start_time)
    status_placeholder.success(f"Text processing completed in {elapsed}s.")
    return result_text
def process_text_openai(text: str, prompt: str, model_id: str) -> str:
    """Run ``prompt`` against ``text`` with an OpenAI chat model.

    Returns the model's reply, or a string starting with ``"Error"`` when the
    client is unavailable or the request fails. Token limit, temperature and
    top-p are read from the session state.
    """
    client = st.session_state.get('openai_client')
    if not (_openai_available and client):
        return "Error: OpenAI client not available or API key missing."

    status_placeholder = st.empty()
    start_time = time.time()
    system_prompt = "You are a helpful assistant. Process the following text based on the user's request."
    full_prompt = f"{prompt}\n\n---\n\n{text}"
    request = {
        "model": model_id,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": full_prompt},
        ],
        "max_tokens": st.session_state.gen_max_tokens,
        "temperature": st.session_state.gen_temperature,
        "top_p": st.session_state.gen_top_p,
    }
    status_placeholder.info(f"Processing text using OpenAI model: {model_id}...")
    try:
        response = client.chat.completions.create(**request)
        result_text = response.choices[0].message.content or ""
        logger.info(f"OpenAI text processing successful for model {model_id}.")
    except Exception as e:
        logger.error(f"OpenAI text processing failed for model {model_id}: {e}")
        result_text = f"Error during OpenAI inference: {str(e)}"

    elapsed = int(time.time() - start_time)
    status_placeholder.success(f"Text processing completed in {elapsed}s.")
    return result_text
def _extract_generated_text(response):
    """Pull the generated text out of an image-to-text response, or return None.

    Accepts a plain string, an object with a ``generated_text`` attribute, a
    dict with that key, or a non-empty list/tuple whose first item is one of
    those.
    """
    if isinstance(response, str):
        return response
    if isinstance(response, (list, tuple)):
        if not response:
            return None
        response = response[0]
        if isinstance(response, str):
            return response
    if isinstance(response, dict):
        return response.get('generated_text')
    return getattr(response, 'generated_text', None)


def process_image_hf(image: Image.Image, prompt: str, use_api: bool, model_id: str = None) -> str:
    """Describe or transcribe ``image`` with a Hugging Face model.

    Args:
        image: PIL image to process.
        prompt: Text prompt; only used by the local ``vision`` model path.
        use_api: True to call the image-to-text inference API, False to use
            the local model selected in the session state.
        model_id: API model override; defaults to
            ``Salesforce/blip-image-captioning-large``.

    Returns:
        The generated text, or a string starting with ``"Error"``.
    """
    status_placeholder = st.empty()
    start_time = time.time()
    result_text = ""
    if use_api:
        status_placeholder.info("Processing image using Hugging Face API...")
        client = get_hf_client()
        if not client:
            return "Error: HF client not configured."
        buffered = BytesIO()
        image.save(buffered, format="PNG" if image.format != 'JPEG' else 'JPEG')
        img_bytes = buffered.getvalue()
        model_id = model_id or "Salesforce/blip-image-captioning-large"
        status_placeholder.info(f"Using API Image-to-Text Model: {model_id}")
        try:
            # image_to_text takes the image as its first positional argument;
            # it has no `data=` keyword.
            response = client.image_to_text(img_bytes, model=model_id)
            generated = _extract_generated_text(response)
            if generated is not None:
                result_text = generated
                logger.info(f"HF API image captioning successful for model {model_id}.")
            else:
                result_text = "Error: Unexpected response format from image-to-text API."
                logger.warning(f"Unexpected API response for image-to-text: {response}")
        except Exception as e:
            logger.error(f"HF API image processing failed: {e}")
            result_text = f"Error during Hugging Face API image inference: {str(e)}"
    else:
        status_placeholder.info("Processing image using local model...")
        if not _transformers_available:
            return "Error: Transformers library needed."
        model_path = st.session_state.get('selected_local_model_path')
        if not model_path or model_path not in st.session_state.get('local_models', {}):
            return "Error: No suitable local model selected."
        local_model_data = st.session_state['local_models'][model_path]
        model_type = local_model_data.get('type')
        if model_type == 'vision':
            processor = local_model_data.get('processor')
            model = local_model_data.get('model')
            if processor and model:
                try:
                    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
                    generated_ids = model.generate(**inputs, max_new_tokens=st.session_state.gen_max_tokens)
                    result_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
                except Exception as e:
                    result_text = f"Error during local vision model inference: {e}"
            else:
                result_text = "Error: Processor or model missing for local vision task."
        elif model_type == 'ocr':
            processor = local_model_data.get('processor')
            model = local_model_data.get('model')
            if processor and model:
                try:
                    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(model.device)
                    generated_ids = model.generate(pixel_values, max_new_tokens=st.session_state.gen_max_tokens)
                    result_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                except Exception as e:
                    result_text = f"Error during local OCR model inference: {e}"
            else:
                result_text = "Error: Processor or model missing for local OCR task."
        else:
            result_text = f"Error: Loaded model '{os.path.basename(model_path)}' is not a recognized vision/OCR type."
    elapsed = int(time.time() - start_time)
    status_placeholder.success(f"Image processing completed in {elapsed}s.")
    return result_text
def process_image_openai(image: Image.Image, prompt: str, model_id: str = "gpt-4o") -> str:
    """Send ``image`` and ``prompt`` to an OpenAI chat model.

    The image is PNG-encoded and embedded as a base64 data URL. Returns the
    model's reply, or a string starting with ``"Error"`` when the client is
    unavailable or the request fails.
    """
    client = st.session_state.get('openai_client')
    if not (_openai_available and client):
        return "Error: OpenAI client not available or API key missing."

    status_placeholder = st.empty()
    start_time = time.time()

    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    img_b64 = base64.b64encode(png_buffer.getvalue()).decode()
    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
    ]

    status_placeholder.info(f"Processing image using OpenAI model: {model_id}...")
    try:
        response = client.chat.completions.create(
            model=model_id,
            messages=[{"role": "user", "content": user_content}],
            max_tokens=st.session_state.gen_max_tokens,
            temperature=st.session_state.gen_temperature,
        )
        result_text = response.choices[0].message.content or ""
        logger.info(f"OpenAI image processing successful for model {model_id}.")
    except Exception as e:
        logger.error(f"OpenAI image processing failed for model {model_id}: {e}")
        result_text = f"Error during OpenAI image inference: {str(e)}"

    elapsed = int(time.time() - start_time)
    status_placeholder.success(f"Image processing completed in {elapsed}s.")
    return result_text
async def process_hf_ocr(image: Image.Image, output_file: str, use_api: bool, model_id: str = None) -> str:
    """Extract text from ``image`` via ``process_image_hf`` and save it to ``output_file``.

    The default model is ``microsoft/trocr-large-handwritten``. A result that
    is empty or starts with ``"Error"`` or ``"["`` is not saved; in that case
    an existing ``output_file`` is removed. Returns the result string, with a
    bracketed note appended if saving failed.
    """
    ocr_prompt = "Extract text content from this image."
    chosen_model = model_id or "microsoft/trocr-large-handwritten"
    result = process_image_hf(image, ocr_prompt, use_api, model_id=chosen_model)

    usable = bool(result) and not result.startswith(("Error", "["))
    if usable:
        try:
            async with aiofiles.open(output_file, "w", encoding='utf-8') as out:
                await out.write(result)
            logger.info(f"HF OCR result saved to {output_file}")
        except IOError as e:
            logger.error(f"Failed to save HF OCR output to {output_file}: {e}")
            result += f"\n[Error saving file: {e}]"
    elif os.path.exists(output_file):
        # Remove the output of an earlier run; failures here are ignored.
        try:
            os.remove(output_file)
        except OSError:
            pass
    return result
async def process_openai_ocr(image: Image.Image, output_file: str, model_id: str = "gpt-4o") -> str:
    """Extract text from ``image`` via ``process_image_openai`` and save it to ``output_file``.

    A result that is empty or starts with ``"Error"`` is not saved; in that
    case an existing ``output_file`` is removed. Returns the result string,
    with a bracketed note appended if saving failed.
    """
    ocr_prompt = "Extract text content from this image."
    result = process_image_openai(image, ocr_prompt, model_id)

    usable = bool(result) and not result.startswith("Error")
    if usable:
        try:
            async with aiofiles.open(output_file, "w", encoding='utf-8') as out:
                await out.write(result)
            logger.info(f"OpenAI OCR result saved to {output_file}")
        except IOError as e:
            logger.error(f"Failed to save OpenAI OCR output to {output_file}: {e}")
            result += f"\n[Error saving file: {e}]"
    elif os.path.exists(output_file):
        # Remove the output of an earlier run; failures here are ignored.
        try:
            os.remove(output_file)
        except OSError:
            pass
    return result
def randomize_character_content():
    """Generate random placeholder content for a character.

    Returns:
        A ``(name, gender, intro, greeting)`` tuple. ``name`` is
        ``Character_<4 digits>``, ``gender`` is ``"Male"`` or ``"Female"``,
        and ``intro`` / ``greeting`` are picked from fixed templates with the
        name substituted.
    """
    intro_templates = [
        "{char} is a valiant knight...", "{char} is a mischievous thief...",
        "{char} is a wise scholar...", "{char} is a fiery warrior...", "{char} is a gentle healer..."
    ]
    # The dash and apostrophes below were mis-encoded in the source (shown as
    # a Greek beta); they are restored to the em dash and right single quote.
    greeting_templates = [
        "'I am from the knight's guild...'", "'I heard you needed help—name’s {char}...",
        "'Oh, hello! I’m {char}, didn’t see you there...'", "'I’m {char}, and I’m here to fight...'",
        "'I’m {char}, here to heal...'"
    ]
    name = f"Character_{random.randint(1000, 9999)}"
    gender = random.choice(["Male", "Female"])
    intro = random.choice(intro_templates).format(char=name)
    greeting = random.choice(greeting_templates).format(char=name)
    return name, gender, intro, greeting
def save_character(character_data):
    """Add ``character_data`` to the session's characters and rewrite ``characters.json``.

    Returns False (after showing an error) when a character with the same
    ``name`` already exists or the file cannot be written, True otherwise.
    """
    characters = st.session_state.get('characters', [])
    for existing in characters:
        if existing['name'] == character_data['name']:
            st.error(f"Character name '{character_data['name']}' already exists.")
            return False

    characters.append(character_data)
    st.session_state['characters'] = characters
    try:
        with open("characters.json", "w", encoding='utf-8') as out:
            json.dump(characters, out, indent=2)
        logger.info(f"Saved character: {character_data['name']}")
        return True
    except IOError as e:
        logger.error(f"Failed to save characters.json: {e}")
        st.error(f"Failed to save character file: {e}")
        return False
def load_characters():
    """Load ``characters.json`` into ``st.session_state['characters']``.

    A missing file yields an empty list. A file whose top-level value is not a
    list is deleted. A file that cannot be read or parsed is copied to
    ``characters_corrupt_<timestamp>.json`` and then deleted; the session list
    is reset to empty in both failure cases.
    """
    if not os.path.exists("characters.json"):
        st.session_state['characters'] = []
        return

    try:
        with open("characters.json", "r", encoding='utf-8') as source:
            loaded = json.load(source)
        if not isinstance(loaded, list):
            st.session_state['characters'] = []
            logger.warning("characters.json is not a list, resetting.")
            os.remove("characters.json")
        else:
            st.session_state['characters'] = loaded
            logger.info(f"Loaded {len(loaded)} characters.")
    except (json.JSONDecodeError, IOError) as e:
        logger.error(f"Failed to load or decode characters.json: {e}")
        st.error(f"Error loading character file: {e}. Starting fresh.")
        st.session_state['characters'] = []
        # Keep a copy of the unreadable file before discarding it.
        try:
            corrupt_filename = f"characters_corrupt_{int(time.time())}.json"
            shutil.copy("characters.json", corrupt_filename)
            logger.info(f"Backed up corrupted character file to {corrupt_filename}")
            os.remove("characters.json")
        except Exception as backup_e:
            logger.error(f"Could not backup corrupted character file: {backup_e}")
def clean_stem(fn: str) -> str:
    """Turn a file path into a title-cased caption.

    The directory and extension are dropped, ``-`` and ``_`` become spaces,
    surrounding whitespace is stripped and the result is title-cased.
    """
    stem, _extension = os.path.splitext(os.path.basename(fn))
    spaced = stem.translate(str.maketrans('-_', '  '))
    return spaced.strip().title()
def make_image_sized_pdf(sources, is_markdown_flags):
    """Build one PDF from a mix of markdown files and images.

    Args:
        sources: Sequence of items. Markdown items are file paths; image items
            are either file paths or file-like objects readable by PIL.
        is_markdown_flags: Sequence of booleans parallel to ``sources``; True
            marks the matching source as markdown.

    Returns:
        The PDF as ``bytes``, or ``None`` when there are no sources or the
        build fails.

    Each markdown source contributes its blank-line-separated paragraphs
    (image links stripped). Each image is scaled down, if necessary, to fit a
    letter page and is followed by a caption derived from its file name and a
    page label. A page break follows every source. Sources that fail are
    reported and skipped.

    Previously image pages were drawn onto a canvas backed by a throw-away
    buffer, so no image ever reached the returned PDF.
    """
    # Local imports: names this block needs that the file does not import at
    # module level.
    from xml.sax.saxutils import escape
    from reportlab.lib.enums import TA_CENTER, TA_RIGHT
    from reportlab.platypus import Image as PlatypusImage

    if not sources:
        st.warning("No sources provided for PDF generation.")
        return None

    buf = BytesIO()
    md_style = ParagraphStyle(
        name='Markdown',
        fontSize=10,
        leading=12,
        spaceAfter=6,
        alignment=TA_JUSTIFY,
        fontName='Helvetica'
    )
    caption_style = ParagraphStyle(
        name='ImageCaption',
        fontSize=12,
        leading=14,
        alignment=TA_CENTER,
        fontName='Helvetica'
    )
    page_label_style = ParagraphStyle(
        name='ImagePageLabel',
        fontSize=8,
        leading=10,
        alignment=TA_RIGHT,
        fontName='Helvetica',
        textColor=(0.5, 0.5, 0.5)
    )
    doc = SimpleDocTemplate(buf, pagesize=letter, rightMargin=36, leftMargin=36, topMargin=36, bottomMargin=36)
    # The document frame pads its content by 6pt on every side; leave room
    # below the image for the caption and the page label.
    max_img_w = doc.width - 12
    max_img_h = doc.height - 12 - 80
    story = []
    try:
        for idx, (src, is_md) in enumerate(zip(sources, is_markdown_flags), start=1):
            status_placeholder = st.empty()
            filename = 'page_' + str(idx)
            status_placeholder.info(f"Adding page {idx}/{len(sources)}: {os.path.basename(str(src))}...")
            try:
                if is_md:
                    with open(src, 'r', encoding='utf-8') as f:
                        content = f.read()
                    # Drop markdown image links; only text is rendered.
                    content = re.sub(r'!\[.*?\]\(.*?\)', '', content)
                    paragraphs = content.split('\n\n')
                    for para in paragraphs:
                        if para.strip():
                            story.append(Paragraph(para.strip(), md_style))
                    story.append(PageBreak())
                    status_placeholder.success(f"Added markdown page {idx}/{len(sources)}: {filename}")
                else:
                    if isinstance(src, str):
                        if not os.path.exists(src):
                            logger.warning(f"Image file not found: {src}. Skipping.")
                            status_placeholder.warning(f"Skipping missing file: {os.path.basename(src)}")
                            continue
                        img_obj = Image.open(src)
                        filename = os.path.basename(src)
                    else:
                        src.seek(0)
                        img_obj = Image.open(src)
                        filename = getattr(src, 'name', f'uploaded_image_{idx}')
                        src.seek(0)
                    with img_obj:
                        iw, ih = img_obj.size
                        if iw <= 0 or ih <= 0:
                            logger.warning(f"Invalid image dimensions ({iw}x{ih}) for {filename}. Skipping.")
                            status_placeholder.warning(f"Skipping invalid image: {filename}")
                            continue
                        # Re-encode into an in-memory PNG so the flowable does
                        # not depend on the source staying open.
                        if img_obj.mode in ("1", "L", "LA", "P", "RGB", "RGBA"):
                            renderable = img_obj
                        else:
                            renderable = img_obj.convert("RGB")
                        png_buffer = BytesIO()
                        renderable.save(png_buffer, format="PNG")
                        png_buffer.seek(0)
                    # Pixels are treated as points; shrink to fit, never enlarge.
                    scale = min(max_img_w / iw, max_img_h / ih, 1.0)
                    story.append(PlatypusImage(png_buffer, width=iw * scale, height=ih * scale))
                    story.append(Spacer(1, 6))
                    # Escape the caption: Paragraph parses its text as markup.
                    story.append(Paragraph(escape(clean_stem(filename)), caption_style))
                    story.append(Paragraph(f"Page {idx}", page_label_style))
                    story.append(PageBreak())
                    status_placeholder.success(f"Added image page {idx}/{len(sources)}: {filename}")
            except Exception as e:
                logger.error(f"Error processing source {src}: {e}")
                status_placeholder.error(f"Error adding page {idx}: {e}")
        doc.build(story)
        buf.seek(0)
        if buf.getbuffer().nbytes < 100:
            st.error("PDF generation resulted in an empty file.")
            return None
        return buf.getvalue()
    except Exception as e:
        logger.error(f"Fatal error during PDF generation: {e}")
        st.error(f"PDF Generation Failed: {e}")
        return None
def update_gallery(gallery_type='image'):
    """Render the gallery for one asset type inside its stored container.

    The container is read from
    ``st.session_state['asset_gallery_container'][gallery_type]``. At most
    ``st.session_state.gallery_size`` files returned by
    ``get_typed_gallery_files(gallery_type)`` are shown. Each file gets a
    collapsible preview, a "Select" checkbox whose value is written to
    ``st.session_state['asset_checkboxes'][gallery_type][file]``, a download
    button and a delete button.

    Deleting removes the file from disk, drops its checkbox entry, removes it
    from ``st.session_state['layout_snapshots']`` when present, and triggers
    ``st.rerun()``.

    Args:
        gallery_type: One of the keys handled below: 'image', 'pdf' or 'md'.
    """
    # Extension -> MIME type for the download button. Loop-invariant, so it is
    # built once instead of once per file.
    mime_map = {'.png': 'image/png', '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.pdf': 'application/pdf', '.md': 'text/markdown'}
    container = st.session_state['asset_gallery_container'][gallery_type]
    with container:
        # NOTE(review): the emoji/icon glyphs in this function's string
        # literals look mis-encoded; confirm the file's encoding.
        st.markdown(f"### {gallery_type.capitalize()} Gallery πΈ")
        files = get_typed_gallery_files(gallery_type)
        if not files:
            st.info(f"No {gallery_type} assets found yet.")
            return
        st.caption(f"Found {len(files)} assets:")
        for file in files[:st.session_state.gallery_size]:
            # NOTE(review): widget keys embed an ever-incrementing counter, so
            # they only stay stable across reruns if the counter is reset at
            # the start of each run -- confirm where 'unique_counter' is set.
            st.session_state['unique_counter'] += 1
            unique_id = st.session_state['unique_counter']
            basename = os.path.basename(file)
            item_key_base = f"{gallery_type}_gallery_item_{basename}_{unique_id}"
            st.markdown(f"**{basename}**")
            try:
                file_ext = os.path.splitext(file)[1].lower()
                if gallery_type == 'image' and file_ext in ['.png', '.jpg', '.jpeg']:
                    with st.expander("Preview", expanded=False):
                        # Close the underlying file handle once rendered.
                        with Image.open(file) as preview_img:
                            st.image(preview_img, use_container_width=True)
                elif gallery_type == 'pdf' and file_ext == '.pdf':
                    with st.expander("Preview (Page 1)", expanded=False):
                        doc = fitz.open(file)
                        # try/finally so the document is released even when
                        # rasterising or displaying the first page fails.
                        try:
                            if len(doc) > 0:
                                pix = doc[0].get_pixmap(matrix=fitz.Matrix(0.5, 0.5))
                                img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                                st.image(img, use_container_width=True)
                            else:
                                st.warning("Empty PDF")
                        finally:
                            doc.close()
                elif gallery_type == 'md' and file_ext == '.md':
                    with st.expander("Preview (Start)", expanded=False):
                        with open(file, 'r', encoding='utf-8', errors='ignore') as f:
                            content_preview = f.read(200)
                        st.code(content_preview + "...", language='markdown')
                action_cols = st.columns(3)
                with action_cols[0]:
                    checkbox_key = f"cb_{item_key_base}"
                    st.session_state['asset_checkboxes'][gallery_type][file] = st.checkbox(
                        "Select",
                        value=st.session_state['asset_checkboxes'][gallery_type].get(file, False),
                        key=checkbox_key
                    )
                with action_cols[1]:
                    mime_type = mime_map.get(file_ext, "application/octet-stream")
                    dl_key = f"dl_{item_key_base}"
                    try:
                        with open(file, "rb") as fp:
                            st.download_button(
                                label="π₯",
                                data=fp,
                                file_name=basename,
                                mime=mime_type,
                                key=dl_key,
                                help="Download this file"
                            )
                    except Exception as dl_e:
                        st.error(f"Download Error: {dl_e}")
                with action_cols[2]:
                    delete_key = f"del_{item_key_base}"
                    if st.button("ποΈ", key=delete_key, help=f"Delete {basename}"):
                        try:
                            os.remove(file)
                            # Forget selection state and any snapshot entry
                            # that still points at the removed file.
                            st.session_state['asset_checkboxes'][gallery_type].pop(file, None)
                            if file in st.session_state.get('layout_snapshots', []):
                                st.session_state['layout_snapshots'].remove(file)
                            logger.info(f"Deleted {gallery_type} asset: {file}")
                            st.toast(f"Deleted {basename}!", icon="β ")
                            st.rerun()
                        except OSError as e:
                            logger.error(f"Error deleting file {file}: {e}")
                            st.error(f"Could not delete {basename}")
            except Exception as e:
                # One broken asset must not prevent the rest from rendering.
                st.error(f"Error displaying {basename}: {e}")
                logger.error(f"Error displaying asset {file}: {e}")
            st.markdown("---")
# --- UI Elements ---
def _option_index(options, value, default=0):
    """Return the position of ``value`` in ``options``, or ``default`` if absent.

    Used for selectbox ``index`` arguments so that a stale value kept in
    session state (one no longer offered as an option) falls back to a valid
    entry instead of raising ``ValueError``.
    """
    try:
        return options.index(value)
    except ValueError:
        return default


# Sidebar section for remote inference: HF token, provider and model choice,
# plus the OpenAI model when the OpenAI client library could be imported.
st.sidebar.subheader("π€ AI Settings")
with st.sidebar.expander("API Inference Settings", expanded=False):
    st.session_state.hf_custom_key = st.text_input(
        "Custom HF Token",
        value=st.session_state.get('hf_custom_key', ""),
        type="password",
        key="hf_custom_key_input"
    )
    # A custom key takes precedence over the module-level HF_TOKEN.
    token_status = "Custom Key Set" if st.session_state.hf_custom_key else ("Default HF_TOKEN Set" if HF_TOKEN else "No Token Set")
    st.caption(f"HF Token Status: {token_status}")
    providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
    st.session_state.hf_provider = st.selectbox(
        "HF Inference Provider",
        options=providers_list,
        index=_option_index(providers_list, st.session_state.get('hf_provider', DEFAULT_PROVIDER)),
        key="hf_provider_select"
    )
    st.session_state.hf_custom_api_model = st.text_input(
        "Custom HF API Model ID",
        value=st.session_state.get('hf_custom_api_model', ""),
        key="hf_custom_model_input"
    )
    st.session_state.hf_selected_api_model = st.selectbox(
        "Featured HF API Model",
        options=FEATURED_MODELS_LIST,
        index=_option_index(FEATURED_MODELS_LIST, st.session_state.get('hf_selected_api_model', FEATURED_MODELS_LIST[0])),
        key="hf_featured_model_select"
    )
    # Computed only after the selectbox has stored this run's choice, so the
    # caption reflects the current selection; a non-blank custom model ID
    # overrides the featured one.
    effective_hf_model = st.session_state.hf_custom_api_model.strip() or st.session_state.hf_selected_api_model
    st.caption(f"Effective HF API Model: {effective_hf_model}")
    if _openai_available:
        st.session_state.openai_selected_model = st.selectbox(
            "OpenAI Model",
            options=OPENAI_MODELS_LIST,
            index=_option_index(OPENAI_MODELS_LIST, st.session_state.get('openai_selected_model', OPENAI_MODELS_LIST[0])),
            key="openai_model_select"
        )
# Sidebar section for choosing which already-registered local model is active.
with st.sidebar.expander("Local Model Selection", expanded=True):
    if _transformers_available:
        # "None" is a sentinel option meaning "no local model"; the remaining
        # options are the keys of st.session_state['local_models'].
        local_model_options = ["None", *st.session_state.get('local_models', {}).keys()]
        current_selection = st.session_state.get('selected_local_model_path', "None")
        # Anything not offered (including a stored None) maps to the sentinel.
        if current_selection not in local_model_options:
            current_selection = "None"
        selected_path = st.selectbox(
            "Active Local Model",
            options=local_model_options,
            index=local_model_options.index(current_selection),
            format_func=lambda option: "None" if option == "None" else os.path.basename(option),
            key="local_model_selector"
        )
        # Persist the choice, translating the sentinel string back to None.
        if selected_path == "None":
            st.session_state.selected_local_model_path = None
        else:
            st.session_state.selected_local_model_path = selected_path
        if not st.session_state.selected_local_model_path:
            st.caption("No local model selected.")
        else:
            model_info = st.session_state.local_models[st.session_state.selected_local_model_path]
            _loaded_model = model_info.get('model')
            st.caption(f"Type: {model_info.get('type', 'Unknown')}")
            st.caption(f"Device: {_loaded_model.device if _loaded_model else 'N/A'}")
    else:
        st.warning("Transformers library not found. Cannot load local models.")
# Sidebar sliders for text-generation parameters. Each slider starts from the
# value already held in session state (or the default shown) and writes the
# chosen value back under the same session-state key.
with st.sidebar.expander("Generation Parameters", expanded=False):
    st.session_state['gen_max_tokens'] = st.slider(
        "Max New Tokens",
        min_value=1,
        max_value=4096,
        value=st.session_state.get('gen_max_tokens', 512),
        key="param_max_tokens",
    )
    st.session_state['gen_temperature'] = st.slider(
        "Temperature",
        min_value=0.01,
        max_value=2.0,
        value=st.session_state.get('gen_temperature', 0.7),
        step=0.01,
        key="param_temp",
    )
    st.session_state['gen_top_p'] = st.slider(
        "Top-P",
        min_value=0.01,
        max_value=1.0,
        value=st.session_state.get('gen_top_p', 0.95),
        step=0.01,
        key="param_top_p",
    )
    # NOTE(review): the label says "Repetition Penalty" but the value is kept
    # under 'gen_frequency_penalty' -- confirm which one consumers expect.
    st.session_state['gen_frequency_penalty'] = st.slider(
        "Repetition Penalty",
        min_value=0.0,
        max_value=1.0,
        value=st.session_state.get('gen_frequency_penalty', 0.0),
        step=0.05,
        key="param_repetition",
    )
    st.session_state['gen_seed'] = st.slider(
        "Seed",
        min_value=-1,
        max_value=65535,
        value=st.session_state.get('gen_seed', -1),
        step=1,
        key="param_seed",
    )
# Gallery size control, followed by rendering of the three asset galleries.
st.sidebar.subheader("πΌοΈ Gallery Settings")
# NOTE(review): this slider is created with st.slider rather than
# st.sidebar.slider, unlike the sidebar calls around it -- confirm the
# intended placement.
st.session_state.gallery_size = st.slider(
    "Max Items Shown",
    2,
    50,
    st.session_state.get('gallery_size', 10),
    key="gallery_size_slider"
)
st.sidebar.markdown("---")
# Render order is fixed: images, then markdown files, then PDFs.
for _gallery_kind in ('image', 'md', 'pdf'):
    update_gallery(_gallery_kind)
# --- Main Application ---
# Page heading and the top-level tab strip; `tabs` is indexed positionally by
# the code that follows, so the label order below is significant.
st.title("Vision & Layout Titans ππΌοΈπ")
st.markdown("Create PDFs from images and markdown, process with AI, and manage characters.")
_tab_labels = [
    "Image/MD->PDF Layout πΌοΈβ‘οΈπ",
    "Camera Snap π·",
    "Download PDFs π₯",
    "Build Titan (Local Models) π±",
    "PDF Process (AI) π",
    "Image Process (AI) πΌοΈ",
    "Text Process (AI) π",
    "Test OCR (AI) π",
    "Test Image Gen (Diffusers) π¨",
    "Character Editor π§βπ¨",
    "Character Gallery πΌοΈ",
]
tabs = st.tabs(_tab_labels)
| with tabs[0]: | |
| st.header("Image/Markdown to PDF Layout Generator") | |
| st.markdown("Select images and markdown files, reorder them, and generate a PDF.") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.subheader("A. Select Assets") | |
| selected_images = [f for f in get_typed_gallery_files('image') if st.session_state['asset_checkboxes']['image'].get(f, False)] | |
| selected_mds = [f for f in get_typed_gallery_files('md') if st.session_state['asset_checkboxes']['md'].get(f, False)] | |
| st.write(f"Selected Images: {len(selected_images)}") | |
| st.write(f"Selected Markdown Files: {len(selected_mds)}") | |
| with col2: | |
| st.subheader("B. Review and Reorder") | |
| layout_records = [] | |
| for idx, path in enumerate(selected_images + selected_mds, start=1): | |
| is_md = path in selected_mds | |
| try: | |
| if is_md: | |
| with open(path, 'r', encoding='utf-8') as f: | |
| content = f.read(50) | |
| layout_records.append({ | |
| "filename": os.path.basename(path), | |
| "source": path, | |
| "type": "Markdown", | |
| "preview": content + "...", | |
| "order": idx | |
| }) | |
| else: | |
| with Image.open(path) as im: | |
| w, h = im.size | |
| ar = round(w / h, 2) if h > 0 else 0 | |
| orient = "Square" if 0.9 <= ar <= 1.1 else ("Landscape" if ar > 1.1 else "Portrait") | |
| layout_records.append({ | |
| "filename": os.path.basename(path), | |
| "source": path, | |
| "type": "Image", | |
| "width": w, | |
| "height": h, | |
| "aspect_ratio": ar, | |
| "orientation": orient, | |
| "order": idx | |
| }) | |
| except Exception as e: | |
| logger.warning(f"Could not process {path}: {e}") | |
| st.warning(f"Skipping invalid file: {os.path.basename(path)}") | |
| if not layout_records: | |
| st.infoperiod |