# Source: Hugging Face repo by Saad4web — "upload 3 files", commit f609cff (verified)
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Configuration ---
# Use environment variable for model ID, default to your Hub repo
MODEL_ID = os.getenv("MODEL_ID", "Saad4web/gemma-3-270m-football-extractor")
# Determine device (CPU is actually fine for 270M if GPU is expensive)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Upper bound on *generated* tokens — passed as max_new_tokens in run_inference.
MAX_LENGTH: int = 1024
# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("football-api")
# --- Global Variables ---
# Populated by lifespan() at startup with "model" and "tokenizer" keys;
# read by run_inference(). Empty dict means the model is not loaded yet.
model_engine: dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model/tokenizer on startup; release them on shutdown.

    Populates the module-level ``model_engine`` dict with the "model" and
    "tokenizer" keys that ``run_inference`` reads.  Re-raises any load
    failure so the server refuses to start with a broken model.
    """
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("🚀 Loading model: %s on %s...", MODEL_ID, DEVICE)
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            # bf16 on GPU for memory/speed; fp32 on CPU where bf16 is slow or unsupported.
            torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
            device_map=DEVICE,
        )
        model_engine["model"] = model
        model_engine["tokenizer"] = tokenizer
        logger.info("✅ Model loaded successfully!")
    except Exception:
        # logger.exception records the full traceback; bare `raise` preserves
        # the original exception and traceback (unlike `raise e`).
        logger.exception("❌ Critical Error loading model")
        raise
    yield
    # Shutdown: drop references so the model can be garbage-collected.
    model_engine.clear()
    logger.info("👋 Shutting down...")
app = FastAPI(title="Football Scout API", version="1.0", lifespan=lifespan)
# Fully open CORS: any origin, method, and header may call this API.
# NOTE(review): fine for a public demo; tighten allow_origins before adding
# credentials or auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Schemas ---
class ExtractionRequest(BaseModel):
    """Request body for POST /extract."""
    # Raw free-form text to extract from; pydantic enforces the 10–2000 char bounds.
    text: str = Field(..., min_length=10, max_length=2000, example="🚨 BREAKING: Man Utd sign Bruno for £55m!")
class ResponseData(BaseModel):
    """Illustrative shape of the extracted payload.

    NOTE(review): not referenced elsewhere in this file — /extract returns
    whatever JSON the model emits, unvalidated; this model exists for
    documentation purposes only.
    """
    post_id: Optional[int] = None
    post_summary: Optional[str] = None
    post_entities: Optional[List[dict]] = None
    # Add other fields from your schema as needed for documentation
# --- Core Logic ---
def run_inference(text: str) -> dict:
    """Run the extraction model on *text* and parse its output as JSON.

    Returns the parsed object.  Raises ``json.JSONDecodeError`` when the
    model emits invalid JSON, and ``KeyError`` if called before the
    lifespan hook has loaded the model into ``model_engine``.
    """
    model = model_engine["model"]
    tokenizer = model_engine["tokenizer"]
    messages = [
        {"role": "system", "content": "You are a data extraction API. Respond ONLY with JSON."},
        {"role": "user", "content": f"Extract structured data from: {text}"},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=MAX_LENGTH,
            # Greedy decoding for deterministic extraction.  (The original
            # also passed temperature=0.1, which transformers ignores — and
            # warns about — when do_sample=False, so it is dropped here.)
            do_sample=False,
            # Fall back to EOS when the tokenizer defines no pad token,
            # avoiding a generate() warning.  Explicit None-check because a
            # valid pad_token_id of 0 is falsy.
            pad_token_id=(tokenizer.pad_token_id
                          if tokenizer.pad_token_id is not None
                          else tokenizer.eos_token_id),
        )
    # Decode only the newly generated tokens (skip the prompt echo).
    result = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
    # Strip Markdown code fences (```json ... ``` or ``` ... ```) the model
    # may wrap around its output.
    clean_json = result.strip()
    if clean_json.startswith("```json"):
        clean_json = clean_json.removeprefix("```json")
    elif clean_json.startswith("```"):
        clean_json = clean_json.removeprefix("```")
    if clean_json.endswith("```"):
        clean_json = clean_json.removesuffix("```")
    return json.loads(clean_json.strip())
# --- Endpoints ---
@app.get("/health")
def health_check():
    """Liveness/readiness probe: report the configured model and device."""
    payload = {
        "status": "ready",
        "model": MODEL_ID,
        "device": DEVICE,
    }
    return payload
@app.post("/extract")
async def extract(request: ExtractionRequest):
    """Extract structured football data from free-form text.

    The model call is synchronous and CPU/GPU-heavy; running it directly in
    this ``async def`` would block the event loop for the whole generation,
    so it is dispatched to a worker thread via ``asyncio.to_thread``.

    Returns ``{"success": True, "data": <parsed JSON>}`` on success;
    raises HTTP 500 on invalid model output or any inference failure.
    """
    try:
        data = await asyncio.to_thread(run_inference, request.text)
        return {"success": True, "data": data}
    except json.JSONDecodeError:
        raise HTTPException(status_code=500, detail="Model generated invalid JSON")
    except Exception as e:
        # logger.exception logs the traceback, not just the message.
        logger.exception("Inference error")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn

    # Honor a platform-provided PORT (containers, PaaS deployments) while
    # keeping the original default of 8000. Run directly without
    # threads/nest_asyncio.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))