"""FastAPI service that predicts JIRA story-point increments.

On import, the fine-tuned checkpoint is downloaded from the Hugging Face Hub,
loaded into ``StoryPointIncrementModel`` and paired with the base-model
tokenizer. If any of those steps fails, ``model`` and ``tokenizer`` are set to
``None`` and ``/predict`` returns an error payload instead of predicting.
"""

import os
import threading
import time
import traceback

# Define the temporary cache directory for all Hugging Face operations
CACHE_DIR = "/tmp/hf"
MAX_RETRIES = 3
RETRY_DELAY_SECONDS = 5

# ----------------------------
# Hugging Face writable paths & download timeout
# ----------------------------
# These must be set BEFORE huggingface_hub / transformers (and my_model, which
# presumably imports transformers) are imported: huggingface_hub copies its
# environment variables into module-level constants at import time, so
# assignments made after the import are ignored.
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
# Seconds to wait for download data before timing out (library default: 10).
# HF_HUB_DOWNLOAD_TIMEOUT is the variable name huggingface_hub actually reads.
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "300"
os.makedirs(CACHE_DIR, exist_ok=True)

# numpy is not used below; kept so the module's import surface is unchanged.
import numpy as np  # noqa: E402, F401
import torch  # noqa: E402
from fastapi import FastAPI, HTTPException, Request  # noqa: E402
from fastapi.concurrency import run_in_threadpool  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from safetensors.torch import load_file  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402

from my_model import StoryPointIncrementModel  # noqa: E402

# Base model with hidden dimension of 384 (must match the fine-tuned checkpoint).
MODEL_BASE_NAME = "sentence-transformers/all-MiniLM-L6-v2"
MODEL_REPO_ID = "AgileGenAI/JIRA-story-point-increment-predictor"
MODEL_FILENAME = "model_fp16.safetensors"
MAX_SEQ_LENGTH = 512

# Optional Hub access token, read from the environment.
HF_AUTH_TOKEN = os.environ.get("HF_TOKEN")


# ----------------------------
# Load safetensors model and tokenizer
# ----------------------------
def _download_checkpoint():
    """Download the checkpoint file, retrying up to ``MAX_RETRIES`` times.

    Returns:
        Local filesystem path of the downloaded safetensors file.

    Raises:
        Exception: whatever ``hf_hub_download`` raised on the final attempt.
        RuntimeError: if ``MAX_RETRIES`` is less than 1 (no attempt was made).
    """
    for attempt in range(1, MAX_RETRIES + 1):
        print(f"Attempting to download model (Attempt {attempt}/{MAX_RETRIES})...")
        try:
            path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=MODEL_FILENAME,
                cache_dir=CACHE_DIR,
                token=HF_AUTH_TOKEN,
                # Ignore any cached copy and re-fetch the file on every start-up.
                force_download=True,
            )
        except Exception as download_error:
            if attempt == MAX_RETRIES:
                raise
            print(
                f"Download attempt {attempt} failed: {download_error}. "
                f"Retrying in {RETRY_DELAY_SECONDS} seconds..."
            )
            time.sleep(RETRY_DELAY_SECONDS)
        else:
            print("Model download succeeded.")
            return path
    raise RuntimeError("Failed to download model after all retries.")


def _remap_state_dict(state_dict, model_keys):
    """Rename checkpoint keys so they line up with the model's parameter names.

    A key that already names one of the model's parameters/buffers is kept
    verbatim. Any other key is renamed with the heuristics below: encoder-looking
    keys get an ``encoder.`` prefix, and ``classifier.*``, ``linear.*``,
    ``output.*`` or bare ``weight``/``bias`` keys are moved under ``regressor.``.

    Args:
        state_dict: mapping of checkpoint key -> tensor.
        model_keys: set of names from the target model's ``state_dict()``.

    Returns:
        A new dict with remapped keys; ``state_dict`` itself is not modified.
    """
    head_prefixes = ("classifier.", "linear.", "output.")
    remapped = {}
    for key, tensor in state_dict.items():
        if key in model_keys:
            # Exact match: re-prefixing (e.g. "encoder.x" -> "encoder.encoder.x")
            # would invent a name the model lacks, and strict=False would then
            # silently discard the fine-tuned weight.
            new_key = key
        elif key.startswith("bert."):
            new_key = f"encoder.{key}"
            without_bert = f"encoder.{key[len('bert.'):]}"
            # Fall back to dropping the "bert." wrapper only when that, and not
            # the plain prefix, names a real model parameter.
            if new_key not in model_keys and without_bert in model_keys:
                new_key = without_bert
        elif "embeddings." in key or "encoder." in key or "pooler." in key:
            new_key = f"encoder.{key}"
        elif key.startswith("regressor."):
            new_key = key
        elif key.startswith(head_prefixes):
            new_key = f"regressor.{key.split('.', 1)[1]}"
        elif key in ("weight", "bias"):
            new_key = f"regressor.{key}"
        else:
            new_key = key
        remapped[new_key] = tensor
    return remapped


def _load_model_and_tokenizer():
    """Download the checkpoint and build the model and tokenizer.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` if any step
        fails. On failure the full traceback is printed to the startup logs.
    """
    if HF_AUTH_TOKEN:
        # Never log any part of the secret itself.
        print("DEBUG: HF_TOKEN successfully retrieved.")
    else:
        print("DEBUG: HF_TOKEN is NOT set in environment variables.")

    try:
        model_path = _download_checkpoint()
        state_dict = load_file(model_path)

        loaded_model = StoryPointIncrementModel(model_name=MODEL_BASE_NAME, cache_dir=CACHE_DIR)
        model_keys = set(loaded_model.state_dict().keys())
        remapped = _remap_state_dict(state_dict, model_keys)

        # strict=False tolerates naming differences but also hides them, so
        # report exactly which weights did and did not land in the model.
        load_result = loaded_model.load_state_dict(remapped, strict=False)
        if load_result.missing_keys:
            print(
                "WARNING: model parameters not found in checkpoint "
                f"(left at their initial values): {load_result.missing_keys}"
            )
        if load_result.unexpected_keys:
            print(f"WARNING: checkpoint entries ignored by the model: {load_result.unexpected_keys}")
        loaded_model.eval()
        print("Model initialized from quantized file.")

        # Load the tokenizer (tokenizer is not affected by precision)
        loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE_NAME, cache_dir=CACHE_DIR)
    except Exception as e:
        print(f"An error occurred during model loading: {e}")
        traceback.print_exc()
        return None, None
    return loaded_model, loaded_tokenizer


# The raw and remapped state dicts are locals of the loader, so they are freed
# once loading finishes instead of lingering as module globals.
model, tokenizer = _load_model_and_tokenizer()

# Serializes inference: the shared tokenizer/model are reused by every request,
# and requests run on worker threads (see predict).
_INFERENCE_LOCK = threading.Lock()


# ----------------------------
# FastAPI app
# ----------------------------
app = FastAPI()


@app.get("/")
def read_root():
    """Returns a simple status message for the root path."""
    return {
        "status": "Story Point Prediction API is running",
        "message": "Use the /predict endpoint with a POST request and JSON body for predictions.",
    }


def _text_field(data, key):
    """Return ``data[key]`` as text, treating a missing key or JSON null as ""."""
    value = data.get(key)
    return "" if value is None else str(value)


def _predict_increment(text_input):
    """Tokenize *text_input*, run the model, and return the rounded prediction.

    This is blocking work; ``predict`` runs it on a worker thread so the event
    loop stays responsive.
    """
    with _INFERENCE_LOCK:
        encoded_input = tokenizer(
            text_input,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=MAX_SEQ_LENGTH,
        )
        with torch.no_grad():
            output = model(
                input_ids=encoded_input["input_ids"],
                attention_mask=encoded_input["attention_mask"],
            )
    # .item() requires a single-element output tensor.
    return int(round(output.item()))


@app.post("/predict")
async def predict(request: Request):
    """Predict the story-point increment for one ticket.

    The request body must be a JSON object with optional ``summary`` and
    ``description`` fields; missing or null fields are treated as empty.

    Returns:
        ``{"story_point_increment": int}`` on success, or ``{"error": ...}``
        when the model failed to load at start-up.

    Raises:
        HTTPException: 400 if the body is not valid JSON or not a JSON object.
    """
    if model is None or tokenizer is None:
        return {"error": "Model initialization failed. Cannot predict. Check startup logs for errors."}

    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")

    summary = _text_field(data, "summary")
    description = _text_field(data, "description")

    # Combine description and summary into a single input text
    text_input = f"{summary} [SEP] {description}"

    story_point_increment = await run_in_threadpool(_predict_increment, text_input)
    return {"story_point_increment": story_point_increment}