Spaces:
Sleeping
Sleeping
| import os | |
| import glob | |
| import json | |
| import psutil | |
| import asyncio | |
| import re | |
| import tempfile | |
| import shutil | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from datetime import datetime, timedelta | |
| from fastapi import FastAPI, Request, HTTPException, UploadFile, File | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from llama_cpp import Llama | |
| try: | |
| import aiohttp | |
| except ImportError: | |
| aiohttp = None | |
app = FastAPI(title="Hannah Pilot Interface")

# --- CORS Permissions ---
# Wide-open CORS: any origin, method and header may call this API from a browser.
# NOTE(review): browsers refuse credentialed CORS responses that carry a wildcard
# origin, so allow_credentials=True together with allow_origins=["*"] deserves a
# second look — confirm how the middleware resolves this combination and whether
# the frontend needs credentials at all.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# --- Configuration ---
# User-facing "Hannah" display names, keyed by the GGUF filename on disk.
# Filenames without an entry are shown under their raw filename.
MODEL_MAP: Dict[str, str] = {
    "qwen2.5-0.5b-instruct-q2_k.gguf": "Hannah-1.1 Light",
    "qwen2.5-0.5b-instruct-q4_k_m.gguf": "Hannah-1.1 Heavy",
}

# Single-slot model cache: the Llama instance currently in memory and the
# filename it was loaded from ("" while nothing is loaded).
current_model: Optional[Llama] = None
current_model_name: str = ""

# --- File Upload Configuration ---
# Uploaded files are stored in a "hannah_uploads" folder under the OS temp dir.
UPLOAD_DIR = Path(tempfile.gettempdir()).joinpath("hannah_uploads")
def _model_abs_path(model_name: str) -> Path:
    """Return the absolute path of *model_name* inside this file's directory.

    Resolution is anchored at the directory containing this source file, not
    the process's current working directory, so the result does not depend
    on where the server was launched from.
    """
    app_dir = Path(__file__).resolve().parent
    candidate = app_dir.joinpath(model_name)
    return candidate.resolve()
def _looks_like_pointer_file(path: Path) -> bool:
    """Return True if *path* looks like a Git LFS pointer instead of a real GGUF.

    llama.cpp cannot load such a pointer, so callers use this to give a
    clearer error. Only the first 256 bytes are read, which keeps the check
    cheap even for multi-hundred-MB model files (the previous version read
    the entire file into memory just to inspect its head).

    Returns False for missing paths, directories, or any path that cannot be
    opened/read.
    """
    try:
        if not path.exists() or path.is_dir():
            return False
        with open(path, "rb") as f:
            head = f.read(256)
    except (OSError, ValueError):
        # Best-effort probe: an unreadable path is simply "not a pointer".
        return False
    # Git LFS pointer files carry the spec URL and an "oid sha256:..." line.
    if b"git-lfs" in head and b"oid sha256" in head:
        return True
    # Some pointer files are plain text starting with "version".
    return head.startswith(b"version ") and b"sha256" in head
def _try_load_model(
    model_path: Path, *, n_ctx: int, n_threads: int, n_batch: int
) -> Llama:
    """Construct a Llama for *model_path* with the given tuning parameters.

    Deliberately minimal so the caller can retry a failed load with
    different n_ctx / n_batch values.
    """
    llama_kwargs = {
        "model_path": str(model_path),
        "n_ctx": n_ctx,
        "n_threads": n_threads,
        "n_batch": n_batch,
        # Memory-mapping the weights tends to be gentler on low-RAM CPU hosts.
        "use_mmap": True,
        "verbose": False,
    }
    return Llama(**llama_kwargs)
def get_model(model_name: str) -> Llama:
    """Return the Llama model for *model_name*, loading it if it is not cached.

    A single model is cached in the module globals ``current_model`` /
    ``current_model_name``. Requesting a different file drops the cached one
    and loads the new one. Loading uses the N_THREADS / N_CTX / N_BATCH
    environment variables (defaults 2 / 4096 / 512). If that fails, it is
    retried once with n_ctx=2048 and n_batch=256.

    Args:
        model_name: GGUF filename, resolved relative to this file's directory.

    Returns:
        The loaded (possibly cached) ``Llama`` instance.

    Raises:
        HTTPException: 400 if no model is given or the name resolves outside
            the app directory; 404 if the file does not exist; 500 if the
            file is a Git LFS pointer or both load attempts fail.
    """
    global current_model, current_model_name
    if not model_name:
        raise HTTPException(status_code=400, detail="No model selected")
    model_path = _model_abs_path(model_name)
    # model_name comes straight from the request body: refuse names such as
    # "../../etc/passwd" that resolve outside the app directory.
    app_dir = Path(__file__).resolve().parent
    if app_dir not in model_path.parents:
        raise HTTPException(status_code=400, detail="Invalid model name")
    if not model_path.exists():
        raise HTTPException(
            status_code=404,
            detail=f"Model file not found: {model_path.name}",
        )
    if _looks_like_pointer_file(model_path):
        raise HTTPException(
            status_code=500,
            detail=(
                "Model file looks like a pointer (not the real .gguf). "
                "Re-upload the GGUF to the Space (so it is stored as the full binary), "
                "then restart the Space."
            ),
        )
    if current_model_name == model_name and current_model is not None:
        return current_model
    try:
        size_mb = model_path.stat().st_size / (1024 * 1024)
    except OSError:
        size_mb = -1
    print(f"Loading {model_path.name} ({size_mb:.1f} MB)...")
    # Drop our reference to the old model before loading the new one, so its
    # memory can be reclaimed first (unless an in-flight request still holds
    # it). Rebind rather than `del`: `del` on a `global` name removes it from
    # the module, so a failed load below used to leave `current_model`
    # undefined and every later call crashed with NameError.
    current_model = None
    current_model_name = ""
    # --- PERFORMANCE TUNING (HF Free CPU) ---
    # Larger context for Hannah 1.1; the retry below falls back to 2048.
    threads = int(os.getenv("N_THREADS", "2"))
    n_ctx = int(os.getenv("N_CTX", "4096"))
    n_batch = int(os.getenv("N_BATCH", "512"))
    try:
        current_model = _try_load_model(
            model_path, n_ctx=n_ctx, n_threads=threads, n_batch=n_batch
        )
    except Exception as e:
        # Retry with conservative settings in case of memory pressure.
        print(f"Model load failed with N_CTX={n_ctx}, N_BATCH={n_batch}: {e}")
        try:
            current_model = _try_load_model(
                model_path, n_ctx=2048, n_threads=threads, n_batch=256
            )
        except Exception as e2:
            print(f"Model load retry failed: {e2}")
            raise HTTPException(
                status_code=500,
                detail=(
                    "Failed to load GGUF model. This is usually caused by: "
                    "(1) model file not fully present inside the container, "
                    "(2) not enough RAM for the chosen context size, or "
                    "(3) llama-cpp-python too old for this GGUF. "
                    f"Model: {model_path.name}"
                ),
            ) from e2
    current_model_name = model_name
    return current_model
async def root():
    """Health check: report that the service is up and which Hannah release it serves."""
    payload = {"status": "ok"}
    payload["name"] = "Hannah-1.1"
    return payload
async def list_models():
    """List loadable GGUF model files with display names and sizes.

    Files are searched in this source file's directory, the same directory
    ``_model_abs_path``/``get_model`` resolve names against. Every listed
    filename can therefore actually be loaded, whatever the process's working
    directory is (the previous version globbed relative to the CWD).

    Returns:
        ``{"models": [{"filename", "display_name", "size"}, ...]}``. Entries
        whose display name contains "Heavy" come first, then the rest in
        display-name order.
    """
    app_dir = Path(__file__).resolve().parent
    models_info: List[Dict[str, Any]] = []
    for model_file in app_dir.glob("*.gguf"):
        try:
            if not model_file.is_file():
                continue
            size_mb = model_file.stat().st_size / (1024 * 1024)
        except OSError:
            # The file vanished or became unreadable between glob and stat.
            continue
        filename = model_file.name
        models_info.append(
            {
                "filename": filename,
                "display_name": MODEL_MAP.get(filename, filename),
                "size": f"{size_mb:.1f} MB",
            }
        )
    # Stable ordering (Heavy first if present)
    models_info.sort(
        key=lambda x: (
            "Heavy" not in x.get("display_name", ""),
            x.get("display_name", ""),
        )
    )
    return {"models": models_info}
async def system_status():
    """Report host RAM in use (MB) and CPU utilisation as display strings."""
    used_mb = psutil.virtual_memory().used / (1024 * 1024)
    cpu_load = psutil.cpu_percent()
    return {"ram_used": f"{used_mb:.0f} MB", "cpu": f"{cpu_load}%"}
async def gen_title(request: Request):
    """Build a short chat title from the first words of the user's message.

    The title is the first four words with the first character upper-cased.
    The rest is left untouched so acronyms and names like "NASA" survive;
    ``str.capitalize`` used to lowercase them. "..." is appended only when
    the message actually had more than four words.

    Any problem reading the body (invalid JSON, a non-object payload, a
    non-string message) or an empty message yields ``{"title": "New Chat"}``.
    """
    max_words = 4
    try:
        data = await request.json()
        message = (data.get("message") or "").strip()
        all_words = message.split()
    except Exception:
        # Best-effort: a title is cosmetic, never fail the request.
        return {"title": "New Chat"}
    if not all_words:
        return {"title": "New Chat"}
    title = " ".join(all_words[:max_words])
    title = title[:1].upper() + title[1:]
    if len(all_words) > max_words:
        title += "..."
    return {"title": title}
def cleanup_old_files(max_age_hours: int = 24, upload_dir: Optional[Path] = None):
    """Remove files older than *max_age_hours* from the upload directory.

    Age is measured from each file's modification time; sub-directories are
    left alone. The sweep is best-effort. A file that disappears or cannot be
    stat'ed or removed (e.g. deleted concurrently by another request) is
    skipped. It no longer aborts the sweep or raises to the caller.

    Args:
        max_age_hours: Files last modified more than this many hours ago are deleted.
        upload_dir: Directory to sweep; defaults to ``UPLOAD_DIR``.
    """
    target = UPLOAD_DIR if upload_dir is None else upload_dir
    if not target.exists():
        return
    cutoff = datetime.now() - timedelta(hours=max_age_hours)
    for file_path in target.glob("*"):
        try:
            if not file_path.is_file():
                continue
            modified = datetime.fromtimestamp(file_path.stat().st_mtime)
            if modified < cutoff:
                file_path.unlink()
        except OSError:
            # Raced with another cleanup, or a permissions problem: skip it.
            continue
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded file under UPLOAD_DIR and return a short text preview.

    The file is saved as "<timestamp>_<name>", where <name> is only the final
    component of the client-supplied filename. A name containing "/" therefore
    cannot place the file outside UPLOAD_DIR. The body is read in 1 MB chunks,
    and an upload over 50 MB is rejected with 413 as soon as the limit is
    crossed, without buffering the remainder.

    Returns:
        dict with "success", "filename" (as sent by the client), "file_url"
        (server-side path of the stored copy), "size_kb", and "preview"
        (first 1000 characters of the content decoded as UTF-8, undecodable
        bytes dropped).

    Raises:
        HTTPException: 413 if the file exceeds 50 MB; 500 for other failures.
    """
    max_bytes = 50 * 1024 * 1024
    chunk_size = 1024 * 1024
    try:
        UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
        chunks: List[bytes] = []
        received = 0
        while True:
            chunk = await file.read(chunk_size)
            if not chunk:
                break
            received += len(chunk)
            if received > max_bytes:
                raise HTTPException(status_code=413, detail="File too large (max 50MB)")
            chunks.append(chunk)
        content = b"".join(chunks)
        # Never trust the client's filename as a path; keep just its basename.
        safe_name = Path(file.filename or "").name or "upload"
        timestamp = datetime.now().timestamp()
        file_path = UPLOAD_DIR / f"{timestamp}_{safe_name}"
        with open(file_path, "wb") as f:
            f.write(content)
        # errors="ignore" never raises, so binary files yield best-effort text.
        preview = content.decode("utf-8", errors="ignore")[:1000]
        # Prune stale uploads. This runs synchronously, and is best-effort: a
        # cleanup problem must not turn a successful upload into an error.
        try:
            cleanup_old_files()
        except Exception as cleanup_err:
            print(f"[Upload] Cleanup of old uploads failed: {cleanup_err}")
        return {
            "success": True,
            "filename": file.filename,
            "file_url": str(file_path),
            "size_kb": len(content) / 1024,
            "preview": preview,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def extract_file_urls(message: str) -> List[str]:
    """Collect file references embedded in a chat message.

    Returns every Google Drive link (in order of appearance), followed by
    every path captured from a "[File uploaded: <path>]" marker.
    """
    drive_links = re.findall(r"https://drive\.google\.com/[^\s\)\"<>]*", message)
    uploaded_paths = re.findall(r"\[File uploaded: ([^\]]+)\]", message)
    return drive_links + uploaded_paths
async def fetch_file_from_url(
    file_url: str,
    max_size: int = 10 * 1024 * 1024,
    allowed_dir: Optional[Path] = None,
) -> str:
    """
    Fetch a file from an uploaded-file path or a URL and return its content as text.
    Works with:
    - Local file paths (uploaded files). Only files inside *allowed_dir*
      (default: UPLOAD_DIR) are read. The path comes from the chat message,
      so without this check any readable server file could be pulled into
      the prompt.
    - Google Drive URLs (sharing links are rewritten to direct-download links)
    - Text files via HTTP

    Expected failures never raise. They are returned as a bracketed note
    (e.g. "[Could not fetch file: HTTP 404]"), which the caller passes to the
    model in place of the file text. At most *max_size* bytes are read from
    disk or the network.

    Args:
        file_url: Local path or URL as extracted from the chat message.
        max_size: Maximum accepted size in bytes (default 10 MB).
        allowed_dir: Directory local paths must live under; defaults to UPLOAD_DIR.

    Returns:
        The first 3000 characters of the decoded text, or a bracketed note.
    """
    max_mb = max_size / (1024 * 1024)
    try:
        local_path = Path(file_url)
        if local_path.is_file():
            root = (UPLOAD_DIR if allowed_dir is None else allowed_dir).resolve()
            resolved = local_path.resolve()
            if root not in resolved.parents:
                return "[Could not read local file: only uploaded files can be read]"
            try:
                size = resolved.stat().st_size
                if size > max_size:
                    return (
                        f"[File too large to process: {size / 1024 / 1024:.1f}MB, "
                        f"max {max_mb:g}MB]"
                    )
                with open(resolved, "rb") as f:
                    # Bounded read in case the file grew after the stat() above.
                    content = f.read(max_size)
            except OSError as e:
                return f"[Could not read local file: {str(e)[:100]}]"
            # errors="ignore" never raises: binary uploads come through as best-effort text.
            return content.decode("utf-8", errors="ignore")[:3000]

        # Handle remote URLs (Google Drive, HTTP, etc.)
        # NOTE(review): any http(s) URL placed in the message (including inside a
        # "[File uploaded: ...]" marker) is fetched server-side; consider an
        # allow-list if this host can reach internal services.
        if not aiohttp:
            return "[File fetching requires aiohttp - install via pip install aiohttp]"
        if "drive.google.com" in file_url:
            # Convert a Drive sharing link to a direct download link.
            file_id_match = re.search(r"/d/([a-zA-Z0-9-_]+)", file_url) or re.search(
                r"id=([a-zA-Z0-9-_]+)", file_url
            )
            if file_id_match:
                file_id = file_id_match.group(1)
                file_url = f"https://drive.google.com/uc?id={file_id}&export=download"
        async with aiohttp.ClientSession() as session:
            async with session.get(
                file_url, timeout=aiohttp.ClientTimeout(total=15), allow_redirects=True
            ) as resp:
                if resp.status != 200:
                    return f"[Could not fetch file: HTTP {resp.status}]"
                declared = resp.content_length
                if declared is not None and declared > max_size:
                    return (
                        f"[File too large to process: {declared / 1024 / 1024:.1f}MB, "
                        f"max {max_mb:g}MB]"
                    )
                # Stream the body and stop once over the limit, instead of
                # buffering an arbitrarily large response in memory.
                buf = bytearray()
                async for chunk in resp.content.iter_chunked(64 * 1024):
                    buf.extend(chunk)
                    if len(buf) > max_size:
                        return f"[File too large to process: more than {max_mb:g}MB]"
        content = bytes(buf)
        try:
            # Limit preview to first 3000 chars
            return content.decode("utf-8")[:3000]
        except UnicodeDecodeError:
            # For binary files, return a note
            return f"[Binary file detected. Size: {len(content) / 1024:.1f}KB. Please describe what you see in it.]"
    except asyncio.TimeoutError:
        return "[File fetch timed out - file may be too large or URL invalid]"
    except Exception as e:
        return f"[Could not fetch file: {str(e)[:100]}]"
def build_prompt(
    user_input: str,
    history: List[Dict[str, str]],
    has_web_context: bool = False,
    max_history: int = 12,
) -> str:
    """Render the conversation as a Qwen 2.5 (ChatML) prompt string.

    The prompt has, in order:
    - one system turn;
    - entries from the last *max_history* items of *history*, keeping only
      dicts whose "role" is "user" or "assistant";
    - the new user turn;
    - an open assistant turn for the model to complete.

    ChatML control markers (``<|im_start|>``, ``<|im_end|>``,
    ``<|endoftext|>``) inside user, history or fetched-file text are defanged
    to ``<im_start>`` etc. Such content therefore cannot open its own
    system/assistant turn.

    Args:
        user_input: The new user message (may already include file/web context).
        history: Prior turns as {"role": ..., "content": ...} dicts; other
            entries are ignored.
        has_web_context: Add instructions to use web search context in the message.
        max_history: How many trailing history entries to consider (0 = none).

    Returns:
        The full prompt string, ending with "<|im_start|>assistant\\n".
    """

    def _defang(text: object) -> str:
        # Neutralise control tokens so untrusted text cannot inject turns.
        return re.sub(r"<\|(im_start|im_end|endoftext)\|>", r"<\1>", str(text))

    # Qwen 2.5 chat format with optional web context awareness
    system = (
        "You are Hannah 1.1, an intelligent, fast, and helpful AI assistant. "
        "Answer clearly and accurately. "
    )
    # If web context is available, instruct the model to use it
    if has_web_context:
        system += (
            "You have been provided with fresh web search context in the user's message. "
            "Use this context to provide current, accurate information about recent events and dates. "
            "Reference the sources when relevant. "
        )
    system += (
        "Keep responses concise but helpful. "
        "If asked about your model or training details, simply say: 'I'm Hannah - a helpful AI assistant.' "
        "Do not discuss GGUF files or internal implementation details."
    )
    parts: List[str] = ["<|im_start|>system\n" + system + "<|im_end|>\n"]
    # Keep a small window of history for speed. Guard 0: a [-0:] slice would keep everything.
    recent = list(history or [])[-max_history:] if max_history > 0 else []
    for msg in recent:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role")
        if role not in ("user", "assistant"):
            continue
        content = msg.get("content") or ""
        parts.append(f"<|im_start|>{role}\n{_defang(content)}<|im_end|>\n")
    parts.append(
        f"<|im_start|>user\n{_defang(user_input)}<|im_end|>\n<|im_start|>assistant\n"
    )
    return "".join(parts)
async def chat(request: Request):
    """Stream a model reply to one chat message as NDJSON.

    Request JSON fields read here:
    - "message" (str, required);
    - "model" (GGUF filename handed to get_model);
    - "history" (list of {"role", "content"} dicts; any other type is ignored);
    - "internet" (truthy when the frontend attached web context).

    Google Drive links and "[File uploaded: ...]" markers in the message are
    fetched, and their text is appended to the message before prompting.

    Response: ``application/x-ndjson``, one ``{"text": <token text>}`` per line.

    Raises:
        HTTPException: 400 if the body is not a JSON object, the message is
            missing/empty/not a string, or "model" is not a string. Errors
            from get_model propagate.
    """
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    raw_message = data.get("message") or ""
    if not isinstance(raw_message, str):
        raise HTTPException(status_code=400, detail="Message must be a string")
    user_input = raw_message.strip()
    model_file = data.get("model")
    if model_file is not None and not isinstance(model_file, str):
        raise HTTPException(status_code=400, detail="Model must be a string")
    history = data.get("history") or []
    if not isinstance(history, list):
        # A malformed history would otherwise fail inside the stream, after
        # the 200 response has already started; ignore it instead.
        history = []
    has_web = bool(data.get("internet", False))  # Was web search enabled?
    if not user_input:
        raise HTTPException(status_code=400, detail="Empty message")
    # Extract and fetch file URLs from the message
    file_content_parts = []
    for url in extract_file_urls(user_input):
        print(f"[File Processing] Fetching: {url[:80]}...")
        content = await fetch_file_from_url(url)
        if content:
            file_content_parts.append(content)
    # Append file contents to user input so the model can process them
    if file_content_parts:
        file_section = "\n\n[File Contents Retrieved]:\n" + "\n---\n".join(
            file_content_parts
        )
        user_input = user_input + file_section
    # NOTE(review): get_model may load a model from disk synchronously here,
    # blocking the event loop for the whole load. iter_response is a plain
    # generator that the server presumably iterates in a worker thread, so
    # concurrent requests could drive the same Llama object at once. Confirm
    # whether requests need to be serialized.
    llm = get_model(model_file)
    # Detect if the message includes web context
    has_web_context = has_web and "[Web context retrieved on" in user_input

    def iter_response():
        prompt = build_prompt(user_input, history, has_web_context=has_web_context)
        stream = llm(
            prompt,
            max_tokens=4096,
            stop=["<|im_end|>", "User:", "System:"],
            stream=True,
        )
        for output in stream:
            token_text = output["choices"][0]["text"]
            yield json.dumps({"text": token_text}) + "\n"

    # NDJSON stream (frontend splits by newlines)
    return StreamingResponse(iter_response(), media_type="application/x-ndjson")