# Football Scout API — FastAPI service serving a fine-tuned Gemma-3 270M
# football-news extraction model (deployed as a Hugging Face Space).
| import os | |
| import torch | |
| import json | |
| import logging | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from typing import Optional, List | |
# --- Configuration ---
# Use environment variable for model ID, default to your Hub repo
MODEL_ID = os.getenv("MODEL_ID", "Saad4web/gemma-3-270m-football-extractor")
# Determine device (CPU is actually fine for 270M if GPU is expensive)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Upper bound on generated tokens per request (passed as max_new_tokens).
MAX_LENGTH = 1024

# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("football-api")

# --- Global Variables ---
# Populated by the lifespan handler with "model" and "tokenizer" entries;
# kept in a dict so shutdown can simply clear() it.
model_engine = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model on startup, unload on shutdown.

    FastAPI requires the lifespan handler to be an async context manager;
    the @asynccontextmanager decorator (imported above but previously never
    applied) is mandatory — without it application startup fails.
    """
    logger.info(f"Loading model: {MODEL_ID} on {DEVICE}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            # bf16 halves memory on GPU; fp32 is the safe CPU default.
            torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
            device_map=DEVICE,
        )
        # Store in the module-level dict so the endpoints can reach them.
        model_engine["model"] = model
        model_engine["tokenizer"] = tokenizer
        logger.info("Model loaded successfully!")
    except Exception:
        # Log with traceback and re-raise: the app must not start
        # without a working model.
        logger.exception("Critical Error loading model")
        raise
    yield
    # Cleanup: drop references so memory can be reclaimed.
    model_engine.clear()
    logger.info("Shutting down...")
# Create the app; the lifespan handler loads/unloads the model.
app = FastAPI(title="Football Scout API", version="1.0", lifespan=lifespan)

# Wide-open CORS — acceptable for a public demo API.
# NOTE(review): tighten allow_origins before any production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
| # --- Schemas --- | |
class ExtractionRequest(BaseModel):
    """Request body for the extraction endpoint."""

    # Raw social-media/news text to extract from; length-bounded to keep
    # prompts within the model's context budget. The example string was
    # mojibake ("Β£55m") in the original — restored to readable text.
    # NOTE(review): `example=` is Pydantic v1-style; under Pydantic v2 prefer
    # `json_schema_extra={"example": ...}` — confirm which version is pinned.
    text: str = Field(..., min_length=10, max_length=2000, example="BREAKING: Man Utd sign Bruno for £55m!")
class ResponseData(BaseModel):
    """Documented (partial) shape of a successful extraction result.

    Used for OpenAPI documentation; the /extract endpoint returns the
    model's parsed JSON directly.
    """

    post_id: Optional[int] = None
    post_summary: Optional[str] = None
    post_entities: Optional[List[dict]] = None
    # Add other fields from your schema as needed for documentation
| # --- Core Logic --- | |
def _strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (``` or ```json) if present."""
    cleaned = text.strip()
    if cleaned.startswith("```json"):
        cleaned = cleaned[7:]
    elif cleaned.startswith("```"):
        # The original only handled the ```json prefix; models sometimes
        # emit a bare ``` fence as well.
        cleaned = cleaned[3:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    return cleaned.strip()


def run_inference(text: str):
    """Run the extraction model on `text` and return the parsed JSON object.

    Raises json.JSONDecodeError when the model output is not valid JSON
    (translated to HTTP 500 by the endpoint), and KeyError if called
    before the lifespan handler has populated model_engine.
    """
    model = model_engine["model"]
    tokenizer = model_engine["tokenizer"]
    messages = [
        {"role": "system", "content": "You are a data extraction API. Respond ONLY with JSON."},
        {"role": "user", "content": f"Extract structured data from: {text}"},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=MAX_LENGTH,
            # Greedy decoding. (The original also passed temperature=0.1,
            # which is ignored when do_sample=False and makes transformers
            # emit a warning — removed.)
            do_sample=False,
            # Some tokenizers define no pad token; fall back to EOS so
            # generate() does not receive pad_token_id=None.
            pad_token_id=tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (skip the prompt).
    result = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
    return json.loads(_strip_code_fence(result))
| # --- Endpoints --- | |
def health_check():
    """Liveness/readiness probe payload: status, model ID and device.

    NOTE(review): no @app.get route decorator is visible here (or on
    `extract` below) — they appear to have been lost in extraction;
    confirm the routes are actually registered before deploying.
    """
    return {"status": "ready", "model": MODEL_ID, "device": DEVICE}
async def extract(request: ExtractionRequest):
    """Run extraction on the request text and wrap the parsed result.

    Maps model-produced invalid JSON to a 500 with a specific message;
    any other inference failure is logged and surfaced as a 500.
    """
    try:
        payload = run_inference(request.text)
    except json.JSONDecodeError:
        raise HTTPException(status_code=500, detail="Model generated invalid JSON")
    except Exception as err:
        logger.error(f"Inference error: {err}")
        raise HTTPException(status_code=500, detail=str(err))
    return {"success": True, "data": payload}
if __name__ == "__main__":
    import uvicorn

    # Run directly without threads/nest_asyncio. The port is overridable
    # via the PORT env var (e.g. Hugging Face Spaces expects 7860),
    # defaulting to the original 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))