Sabithulla committed on
Commit
2a72045
·
1 Parent(s): 6e4922a

Add FastAPI backend with Docker for HuggingFace Spaces

Browse files
Files changed (6) hide show
  1. Dockerfile +29 -0
  2. database.py +50 -0
  3. main.py +127 -0
  4. model_manager.py +192 -0
  5. ocr_engine.py +30 -0
  6. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies for llama-cpp and image processing.
# --no-install-recommends keeps the image slim.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    libopenblas-dev \
    tesseract-ocr \
    libtesseract-dev \
    && rm -rf /var/lib/apt/lists/*

# Flush Python logs straight to stdout so the Spaces log viewer shows them live.
ENV PYTHONUNBUFFERED=1

# Copy requirements first for better caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Create models directory
RUN mkdir -p models

# Expose port 7860 (HuggingFace Spaces default)
EXPOSE 7860

# Run Uvicorn on port 7860
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--timeout-keep-alive", "75"]
database.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from supabase import create_client, Client
3
+ from dotenv import load_dotenv
4
+
5
+ load_dotenv()
6
+
7
class DatabaseManager:
    """Thin wrapper around the Supabase client for chat-message persistence.

    If SUPABASE_URL / SUPABASE_KEY are not configured, every method degrades
    to a no-op so the API can still run without a database.
    """

    def __init__(self):
        url = os.environ.get("SUPABASE_URL")
        key = os.environ.get("SUPABASE_KEY")
        if url and key:
            self.supabase: Client = create_client(url, key)
        else:
            # Degraded mode: all methods below short-circuit on None.
            self.supabase = None
            print("Warning: Supabase credentials missing. Database functionality will be disabled.")

    def store_message(self, user_id: str, role: str, content: str, model_used: str):
        """Insert one chat message row; returns the Supabase response, or None
        when the database is disabled."""
        if not self.supabase:
            return None

        data = {
            "user_id": user_id,
            "role": role,
            "content": content,
            "model_used": model_used,
        }
        return self.supabase.table("messages").insert(data).execute()

    def get_history(self, user_id: str):
        """Return the user's messages from the last 24 hours, oldest first."""
        if not self.supabase:
            return []

        from datetime import datetime, timedelta, timezone

        # Enforce the documented 24-hour window (previously the comment
        # claimed it but no filter was applied).
        cutoff = (datetime.now(timezone.utc) - timedelta(days=1)).isoformat()
        return (
            self.supabase.table("messages")
            .select("*")
            .eq("user_id", user_id)
            .gte("created_at", cutoff)
            .order("created_at", desc=False)
            .execute()
        )

    def cleanup_old_messages(self):
        """Delete messages older than 24 hours (intended for a cron job).

        Equivalent SQL: DELETE FROM messages WHERE created_at < NOW() - INTERVAL '1 day';
        NOTE(review): previously a stub even though /cleanup calls it —
        confirm the service role key permits deletes.
        """
        if not self.supabase:
            return None

        from datetime import datetime, timedelta, timezone

        cutoff = (datetime.now(timezone.utc) - timedelta(days=1)).isoformat()
        return (
            self.supabase.table("messages")
            .delete()
            .lt("created_at", cutoff)
            .execute()
        )

db_manager = DatabaseManager()
main.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Body, HTTPException, Request
2
+ from fastapi.responses import StreamingResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from pydantic import BaseModel
5
+ import uvicorn
6
+ import os
7
+ import json
8
+ import sys
9
+ from dotenv import load_dotenv
10
+ from typing import Optional, List
11
+ import logging
12
+
13
+ from model_manager import model_manager
14
+ from ocr_engine import ocr_engine
15
+ from database import db_manager
16
+
17
# Setup logging
# Logs go to stdout so the container/platform log viewer picks them up.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)
logger = logging.getLogger(__name__)

# Load .env values (Supabase credentials, PORT, ...) before anything reads them.
load_dotenv()

app = FastAPI(title="AI Platform API")

# Configure CORS
# Only the deployed frontend and local dev servers may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://frontend-one-gamma-14.vercel.app",
        "http://localhost:3000",  # For local development
        "http://localhost:8000"
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
41
+
42
@app.get("/")
async def root():
    """Landing route: advertises the service name, version and endpoint map."""
    endpoint_map = {
        "health": "/health",
        "chat": "/chat",
        "upload": "/upload-image",
        "cleanup": "/cleanup",
    }
    return {
        "name": "Alpha Core AI API",
        "version": "1.0.0",
        "status": "online",
        "endpoints": endpoint_map,
    }
55
+
56
@app.get("/health")
async def health_check():
    """Liveness probe used by the hosting platform."""
    payload = {"status": "healthy", "version": "1.0.0"}
    return payload
59
+
60
class ChatRequest(BaseModel):
    # Request payload for POST /chat.
    message: str                          # user prompt text
    model: str = "tinyllama"              # key into ModelManager.model_configs
    user_id: str = "default_user"         # identifies the conversation owner in storage
    context: Optional[List[dict]] = None  # prior turns, presumably [{"role": ..., "content": ...}] — see format_prompt
    # Sampling parameters forwarded to llama-cpp.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = 2048
    repeat_penalty: Optional[float] = 1.1
69
+
70
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Stream model tokens to the client as Server-Sent Events.

    Persists both the user prompt and the finished assistant reply once the
    stream completes. Errors raised inside the generator are reported as an
    SSE error event (HTTP headers are already sent by then); errors before
    streaming starts become a 500.
    """
    try:
        logger.info(f"Chat request: model={request.model}, user={request.user_id}")

        def stream_response():
            full_response = ""
            try:
                # Pass context and sampling settings to the model manager.
                params = {
                    "temperature": request.temperature,
                    "top_p": request.top_p,
                    "max_tokens": request.max_tokens,
                    "repeat_penalty": request.repeat_penalty
                }
                for token in model_manager.generate_stream(request.model, request.message, request.context, **params):
                    full_response += token
                    yield f"data: {json.dumps({'token': token})}\n\n"

                logger.info(f"Response generated: {len(full_response)} tokens")

                # Store both sides of the exchange.
                # BUGFIX: DatabaseManager.store_message signature is
                # (user_id, role, content, model_used) — role must come before
                # content; the old call swapped them.
                db_manager.store_message(request.user_id, "user", request.message, request.model)
                db_manager.store_message(request.user_id, "assistant", full_response, request.model)

                yield "data: [DONE]\n\n"
            except Exception as e:
                logger.error(f"Stream error: {str(e)}", exc_info=True)
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(stream_response(), media_type="text/event-stream")
    except Exception as e:
        logger.error(f"Chat endpoint error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
104
+
105
@app.post("/upload-image")
async def upload_image(file: UploadFile = File(...)):
    """OCR an uploaded image and return the extracted text as {"text": ...}."""
    # content_type may be absent entirely (None); the old check would raise
    # AttributeError instead of a clean 400 in that case.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        content = await file.read()
        extracted_text = ocr_engine.extract_text(content)
        return {"text": extracted_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
116
+
117
@app.get("/cleanup")
async def cleanup_chats():
    """Trigger deletion of stale chat history (intended for a cron ping)."""
    try:
        db_manager.cleanup_old_messages()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"message": "Cleanup successful"}
124
+
125
if __name__ == "__main__":
    # Local/dev entry point; the Docker image launches Uvicorn via CMD instead.
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)
model_manager.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from llama_cpp import Llama
3
+ import requests
4
+ from typing import Generator
5
+
6
class ModelManager:
    """Downloads, caches and serves GGUF chat models via llama-cpp.

    NOTE(review): the constructor eagerly downloads EVERY configured model
    (tens of GB for the 7B / 8x7B entries) — confirm the host has the disk
    and bandwidth for this, or switch to lazy download via load_model().
    """

    def __init__(self):
        # model_id -> loaded Llama instance (populated lazily by load_model).
        self.models = {}
        # Download location and prompt-template family for each model.
        # NOTE(review): several entries are labelled "chatml" (e.g. mistral,
        # zephyr) whose upstream cards document other native templates —
        # kept as configured; verify output quality per model.
        self.model_configs = {
            "tinyllama": {
                "repo": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
                "file": "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
                "url": "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
                "format": "tinyllama"
            },
            "phi": {
                "repo": "TheBloke/phi-2-GGUF",
                "file": "phi-2.Q4_K_M.gguf",
                "url": "https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q4_K_M.gguf",
                "format": "phi"
            },
            "coder": {
                "repo": "Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF",
                "file": "qwen2.5-coder-1.5b-instruct-q4_k_m.gguf",
                "url": "https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/resolve/main/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf",
                "format": "chatml"
            },
            "orca": {
                "repo": "bartowski/Llama-3.2-3B-Instruct-GGUF",
                "file": "Llama-3.2-3B-Instruct-Q4_K_M.gguf",
                "url": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_K_M.gguf",
                "format": "llama3"
            },
            "fast-chat": {
                "repo": "Qwen/Qwen2.5-0.5B-Instruct-GGUF",
                "file": "qwen2.5-0.5b-instruct-q4_k_m.gguf",
                "url": "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/qwen2.5-0.5b-instruct-q4_k_m.gguf",
                "format": "chatml"
            },
            "mistral": {
                "repo": "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
                "file": "mistral-7b-instruct-v0.2.Q4_K_M.gguf",
                "url": "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf",
                "format": "chatml"
            },
            "neural": {
                "repo": "TheBloke/neural-chat-7B-v3-1-GGUF",
                "file": "neural-chat-7b-v3-1.Q4_K_M.gguf",
                "url": "https://huggingface.co/TheBloke/neural-chat-7B-v3-1-GGUF/resolve/main/neural-chat-7b-v3-1.Q4_K_M.gguf",
                "format": "chatml"
            },
            "zephyr": {
                "repo": "TheBloke/zephyr-7B-beta-GGUF",
                "file": "zephyr-7b-beta.Q4_K_M.gguf",
                "url": "https://huggingface.co/TheBloke/zephyr-7B-beta-GGUF/resolve/main/zephyr-7b-beta.Q4_K_M.gguf",
                "format": "chatml"
            },
            "openhermes": {
                "repo": "TheBloke/OpenHermes-2.5-Mistral-7B-GGUF",
                "file": "openhermes-2.5-mistral-7b.Q4_K_M.gguf",
                "url": "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf",
                "format": "chatml"
            },
            "starling": {
                "repo": "TheBloke/Starling-LM-7B-alpha-GGUF",
                "file": "starling-lm-7b-alpha.Q4_K_M.gguf",
                "url": "https://huggingface.co/TheBloke/Starling-LM-7B-alpha-GGUF/resolve/main/starling-lm-7b-alpha.Q4_K_M.gguf",
                "format": "chatml"
            },
            "dolphin": {
                "repo": "TheBloke/dolphin-2.5-mixtral-8x7b-GGUF",
                "file": "dolphin-2.5-mixtral-8x7b.Q4_K_M.gguf",
                "url": "https://huggingface.co/TheBloke/dolphin-2.5-mixtral-8x7b-GGUF/resolve/main/dolphin-2.5-mixtral-8x7b.Q4_K_M.gguf",
                "format": "chatml"
            }
        }
        self.models_dir = os.path.join(os.getcwd(), "models")
        os.makedirs(self.models_dir, exist_ok=True)
        # Proactively download all models
        self.auto_download_all()

    def auto_download_all(self):
        """Best-effort download of every configured model; failures are logged
        and skipped so one bad URL doesn't block startup."""
        print("Starting proactive model download (Auto-Download Phase)...")
        for model_id in self.model_configs:
            try:
                self.download_model(model_id)
            except Exception as e:
                print(f"Failed to auto-download {model_id}: {e}")

    def download_model(self, model_id: str):
        """Ensure the GGUF file for model_id exists locally; return its path.

        Raises ValueError for unknown model ids, and re-raises any download
        error after removing the partial file.
        """
        config = self.model_configs.get(model_id)
        if not config:
            raise ValueError(f"Model {model_id} not configured")

        target_path = os.path.join(self.models_dir, config["file"])
        # Treat tiny files as failed/partial downloads and re-fetch.
        if os.path.exists(target_path) and os.path.getsize(target_path) > 50000000:  # Min 50MB
            return target_path

        print(f"Downloading {model_id} from {config['url']}...")
        try:
            response = requests.get(config["url"], stream=True, timeout=60)
            response.raise_for_status()
            with open(target_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):  # 1MB chunks
                    if chunk:
                        f.write(chunk)
            print(f"Successfully downloaded {model_id}")
            return target_path
        except Exception as e:
            # Remove partial files so the size check above re-triggers a fetch.
            if os.path.exists(target_path):
                os.remove(target_path)
            print(f"Download failed for {model_id}: {e}")
            raise  # bare raise preserves the original traceback (was `raise e`)

    def load_model(self, model_id: str):
        """Return a cached Llama instance, downloading/loading it on first use."""
        if model_id in self.models:
            return self.models[model_id]

        path = self.download_model(model_id)
        self.models[model_id] = Llama(
            model_path=path,
            n_ctx=2048,  # Standard context
            n_threads=4,
            verbose=False
        )
        return self.models[model_id]

    def format_prompt(self, model_id: str, system: str, history: list, prompt: str):
        """Build the model-family-specific prompt string.

        Returns (full_prompt, stop_tokens). History entries are dicts with
        "role" and "content" keys; any role other than "user" is treated as
        "assistant". Unknown formats fall back to the raw prompt.
        """
        fmt = self.model_configs[model_id]["format"]

        if fmt == "chatml":
            full = f"<|im_start|>system\n{system}<|im_end|>\n"
            for msg in history:
                role = "user" if msg["role"] == "user" else "assistant"
                full += f"<|im_start|>{role}\n{msg['content']}<|im_end|>\n"
            full += f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
            return full, ["<|im_end|>", "###", "<|im_start|>", "</s>"]

        elif fmt == "tinyllama":
            full = f"<|system|>\n{system}</s>\n"
            for msg in history:
                role = "user" if msg["role"] == "user" else "assistant"
                full += f"<|{role}|>\n{msg['content']}</s>\n"
            full += f"<|user|>\n{prompt}</s>\n<|assistant|>\n"
            return full, ["</s>", "<|user|>", "<|assistant|>"]

        elif fmt == "llama3":
            # Llama 3.2 header-tag template
            full = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
            for msg in history:
                role = "user" if msg["role"] == "user" else "assistant"
                full += f"<|start_header_id|>{role}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
            full += f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            return full, ["<|eot_id|>", "<|start_header_id|>", "</s>"]

        elif fmt == "phi":
            # Phi-2 Instruct/Output prompt (this template carries no history)
            full = f"Instruct: {system}\n{prompt}\nOutput:"
            return full, ["Instruct:", "Output:", "<|endoftext|>", "</s>"]

        # Fallback for unknown formats. (A duplicated, unreachable copy of
        # this return was removed.)
        return prompt, ["</s>"]

    def generate_stream(self, model_id: str, prompt: str, context: list = None, **kwargs) -> Generator[str, None, None]:
        """Yield generated text tokens one at a time for the given prompt."""
        llm = self.load_model(model_id)

        system_text = (
            "You are a highly accurate AI assistant. "
            "For math, ALWAYS use LaTeX wrapping display equations in [ ] and inline in ( )."
        )

        full_prompt, stop_tokens = self.format_prompt(model_id, system_text, context or [], prompt)

        # Use provided kwargs or defaults (mirrors ChatRequest's defaults).
        params = {
            "max_tokens": kwargs.get("max_tokens", 2048),
            "stop": stop_tokens,
            "stream": True,
            "temperature": kwargs.get("temperature", 0.7),
            "top_p": kwargs.get("top_p", 0.95),
            "repeat_penalty": kwargs.get("repeat_penalty", 1.1)
        }

        for output in llm(full_prompt, **params):
            token = output["choices"][0]["text"]
            yield token

model_manager = ModelManager()
ocr_engine.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytesseract
2
+ from PIL import Image
3
+ import io
4
+ import os
5
+
6
class OCREngine:
    """Wraps pytesseract to pull text out of uploaded image bytes."""

    def __init__(self):
        # Linux containers are expected to have tesseract on PATH; on Windows
        # we must point pytesseract at the default install location explicitly.
        if os.name == 'nt':
            pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'

    def extract_text(self, image_content: bytes) -> str:
        """Run OCR on raw image bytes.

        Returns the stripped recognized text, or an "Error extracting text:"
        message string if anything fails (never raises).
        """
        try:
            img = Image.open(io.BytesIO(image_content))

            # Basic preprocessing: cap very large images to bound OCR time.
            if max(img.width, img.height) > 2000:
                img.thumbnail((2000, 2000))

            # Grayscale conversion before running tesseract.
            grayscale = img.convert('L')
            return pytesseract.image_to_string(grayscale).strip()
        except Exception as e:
            print(f"OCR Error: {e}")
            return f"Error extracting text: {str(e)}"

ocr_engine = OCREngine()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
fastapi
uvicorn
llama-cpp-python
supabase
python-multipart
pytesseract
pillow
python-dotenv
aiohttp
requests