import logging
import os

import modal
from fastapi import Header

from models import MODEL_IDS

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

CACHE_DIR = "/cache"

image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install("torch", "transformers", "accelerate", "fastapi", "bitsandbytes")
    .add_local_dir("site", "/root")
)

app = modal.App("posttraining-chat", image=image)
cache_vol = modal.Volume.from_name("hf-cache", create_if_missing=True)

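# NOTE: the decorator below is a reconstruction. The class needs to run under
# @app.cls for the Volume commit and the CUDA calls to work; the GPU type and
# the name of the secret supplying MODEL_SITE_API_KEY are assumptions.
@app.cls(
    gpu="A10G",  # assumed; any CUDA GPU with enough VRAM for the models works
    volumes={CACHE_DIR: cache_vol},
    secrets=[modal.Secret.from_name("model-site-api-key")],  # hypothetical secret name
)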
class Inference:
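    # Assumed @modal.enter() so setup() runs once at container start; the
    # method is otherwise never called.
    @modal.enter()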
    def setup(self):
        os.environ["HF_HOME"] = CACHE_DIR
        self.models = {}

    def load_model(self, model_id: str):
        # Models are loaded lazily on first request and cached in memory for
        # the lifetime of the container.
        if model_id in self.models:
            logger.info(f"Model already loaded: {model_id}")
            return

        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info(f"Loading model: {model_id}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            logger.info(f"Tokenizer loaded for {model_id}")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            logger.info(f"Model loaded successfully: {model_id}")
            self.models[model_id] = {"model": model, "tokenizer": tokenizer}
            # Persist any freshly downloaded weights to the shared HF cache volume.
            cache_vol.commit()
        except Exception as e:
            logger.error(f"Failed to load model {model_id}: {e}")
            raise

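    # Assumed endpoint decorator: Header(...) in the signature only works under
    # a FastAPI-backed web endpoint, so this (or @modal.web_endpoint on older
    # Modal releases) must wrap generate().
    @modal.fastapi_endpoint(method="POST")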
    def generate(self, request: dict, x_api_key: str | None = Header(None)) -> dict:
        import torch

        # Log only a truncated preview of the incoming message, not the full prompt.
        logger.info(
            f"Received request: model_id={request.get('model_id')}, "
            f"message_len={len(request.get('message', ''))}, "
            f"history_len={len(request.get('history', []))}, "
            f"message: {request.get('message', '')[:100]}..."
        )

        expected_key = os.environ.get("MODEL_SITE_API_KEY")
        if not expected_key or x_api_key != expected_key:
            logger.warning("Auth failed: invalid or missing API key")
            return {"error": "Unauthorized - invalid API key"}

        model_id = request.get("model_id", MODEL_IDS[0])
        message = request.get("message", "")
        history = request.get("history", [])

        if model_id not in MODEL_IDS:
            logger.warning(f"Model not found: {model_id}")
            return {"error": f"Model {model_id} not found"}

        try:
            self.load_model(model_id)
        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            return {"error": f"Failed to load model: {e}"}

        tokenizer = self.models[model_id]["tokenizer"]
        model = self.models[model_id]["model"]

        # Rebuild the conversation, then render it with the model's chat
        # template, ending with the assistant generation prompt.
        messages = []
        for msg in history:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            messages.append({"role": role, "content": content})
        messages.append({"role": "user", "content": message})

        conversation = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        try:
            # Caveat: many chat templates already emit the BOS token, so
            # re-tokenizing the rendered string may prepend a second one.
            inputs = tokenizer(conversation, return_tensors="pt").to("cuda")
            logger.info(f"Tokenized input shape: {inputs['input_ids'].shape}")
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=1024,
                    do_sample=True,
                    temperature=0.4,
                    top_p=0.85,
                    repetition_penalty=1.15,
                    pad_token_id=tokenizer.eos_token_id,
                )
            logger.info(f"Generated output shape: {outputs.shape}")
            # Extract only the newly generated tokens (skip the input)
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            logger.info(f"Final response length: {len(response)}")
            logger.info(f"Response: {response}")
            return {"response": response}
        except Exception as e:
            logger.error(f"Inference failed: {e}", exc_info=True)
            return {"error": f"Inference failed: {e}"}