# gemma-inference / app.py
# Leon4gr45's picture
# Upload gemma-inference space
# 412d837 verified
import os
from fastapi import FastAPI
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
# The FastAPI application object; its title is reused for the Swagger UI page
# served at "/".
app = FastAPI(title="Schematron-3B Inference API")
print("Initializing API...")
# For now, use a lightweight approach: no model is loaded in this file and the
# endpoints below return mock or pattern-matched results.
# In production, use llama.cpp or properly configured transformers.
model_loaded = True  # Simple mock mode; this flag is reported as-is by /health
# Pydantic models for OpenAI-compatible API
class ModelCard(BaseModel):
    """One entry of the model listing returned by ``GET /v1/models``."""

    # Identifier of the model; the only field without a default.
    id: str
    # Constant tag marking this object as a model entry.
    object: str = "model"
    # Constant default; presumably a placeholder Unix timestamp -- confirm.
    created: int = 1677649963
    owned_by: str = "inference-net"
class ModelList(BaseModel):
    """Envelope for a model listing: an ``object`` tag plus the entries."""

    # Constant tag marking this object as a list.
    object: str = "list"
    # The model entries contained in the listing.
    data: List[ModelCard]
class CompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/completions``.

    ``model`` and ``prompt`` are required; the sampling fields have defaults.
    """

    model: str
    prompt: str
    max_tokens: int = 2000
    temperature: float = 0.0
class ChatMessage(BaseModel):
    """A single chat message: who said it (``role``) and what (``content``)."""

    # The chat endpoint in this file acts on the roles "system" and "user".
    role: str
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``.

    ``model`` and ``messages`` are required; the sampling fields have defaults.
    """

    model: str
    # Conversation so far, in order.
    messages: List[ChatMessage]
    max_tokens: int = 2000
    temperature: float = 0.0
class ExtractionRequest(BaseModel):
    """Request body accepted by ``POST /extract``."""

    # Raw HTML text that the endpoint scans for links.
    html: str
    # Caller-supplied schema describing the desired output.
    # NOTE(review): ``schema`` is also the name of a BaseModel attribute in
    # pydantic; depending on the installed version this warns or is rejected.
    # Renaming it (with an alias) requires updating the /extract handler too.
    schema: Dict[str, Any]
@app.get("/", include_in_schema=False)
async def custom_swagger_ui_html():
    """Serve the Swagger UI page at the site root.

    The route itself is excluded from the generated OpenAPI schema. The page
    loads the schema from ``app.openapi_url`` and uses the application title
    as the basis of its own title.
    """
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=app.title + " - Swagger UI",
    )
@app.get("/health")
def health_check():
    """Return a fixed ``"ok"`` status plus the module-level ``model_loaded`` flag."""
    return {"status": "ok", "model_loaded": model_loaded}
@app.get("/v1/models", response_model=ModelList)
def list_models():
    """Return the model listing, which contains the single id ``schematron-3b``."""
    return ModelList(data=[ModelCard(id="schematron-3b")])
@app.post("/v1/completions")
def create_completion(request: CompletionRequest):
    """Return a fixed message pointing callers at the ``/extract`` endpoint.

    The request body is validated against ``CompletionRequest`` but none of
    its fields are used; no text is generated.
    """
    return {"generated_text": "Use /extract endpoint for schema-based extraction"}
@app.post("/v1/chat/completions")
def create_chat_completion(request: ChatCompletionRequest):
    """Return a mock chat completion that echoes the start of the prompt.

    The prompt is built from every ``system`` message (in order, each
    followed by a newline) followed by the content of the last ``user``
    message. Messages with any other role are ignored. No model is invoked:
    the reply is the first 100 characters of that prompt wrapped in a fixed
    template.

    Args:
        request: Validated chat-completion request body.

    Returns:
        A dict with a single-element ``choices`` list and the echoed
        ``model`` name.
    """
    # Build prompt from messages.
    system_content = "".join(
        f"{msg.content}\n" for msg in request.messages if msg.role == "system"
    )
    # Only the most recent user message contributes; empty if there is none.
    user_content = next(
        (msg.content for msg in reversed(request.messages) if msg.role == "user"),
        "",
    )
    # Plain concatenation already yields just the user text when there are no
    # system messages, so no conditional is needed.
    full_prompt = system_content + user_content

    # Return mock for now.
    return {
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": f"Extracted: {full_prompt[:100]}...",
                },
                "finish_reason": "stop",
            }
        ],
        "model": request.model,
    }
# Only this many leading characters of the submitted HTML are scanned.
_MAX_HTML_CHARS = 50000
# At most this many event links are returned per request.
_MAX_EVENTS = 10


def _find_event_links(html_text, limit=_MAX_EVENTS):
    """Return up to *limit* ``{"title", "url"}`` dicts for event links in *html_text*.

    A link qualifies when it is an ``<a>`` tag whose ``href`` starts with
    ``/event/`` or ``/events/`` and whose body contains no nested tags.
    """
    import re

    link_matches = re.findall(
        r'<a[^>]+href=["\'](/events?/[^"\']+)["\'][^>]*>([^<]*)</a>',
        html_text,
        re.I,
    )
    return [
        # Whitespace-only anchor text is treated as missing so the numbered
        # placeholder is used instead of a blank title.
        {"title": title.strip() or f"Event {position}", "url": url}
        for position, (url, title) in enumerate(link_matches[:limit], start=1)
    ]


@app.post("/extract")
def extract_json(request: ExtractionRequest):
    """Extract event links from the submitted HTML with a simple pattern match.

    This is a demo implementation: ``request.schema`` is accepted but not
    consulted, and only the first ``_MAX_HTML_CHARS`` characters of
    ``request.html`` are scanned.

    Args:
        request: Validated extraction request body.

    Returns:
        ``{"success": True, "data": {"events": [...]}}`` where each event has
        a ``title`` and a ``url``.
    """
    html_text = request.html[:_MAX_HTML_CHARS]
    # Simple pattern matching for demo.
    events = _find_event_links(html_text)
    return {"success": True, "data": {"events": events}}
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    import uvicorn

    # Listen on all interfaces on port 7860.
    # NOTE(review): presumably the port the hosting environment expects -- confirm.
    uvicorn.run(app, host="0.0.0.0", port=7860)