Spaces:
Runtime error
Runtime error
| # app.py | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer | |
| from peft import AutoPeftModelForCausalLM | |
| import torch | |
| from typing import Optional | |
| import os | |
# Cache location for Hugging Face downloads.
# NOTE(review): transformers is imported above, before this assignment runs.
# The HF libraries appear to resolve HF_HOME when first imported, so this
# likely does not change their default cache directory — confirm. The model
# load below passes cache_dir explicitly.
os.environ.update(HF_HOME='/app/cache')

# Access token read from the environment; None when HF_TOKEN is unset.
hf_token = os.environ.get('HF_TOKEN')

# Hub repository holding the PEFT adapter that is loaded below.
MODEL_NAME = "Sidharthan/gemma2_scripter"

app = FastAPI(title="Gemma Script Generator API")
# Load the tokenizer and the PEFT-adapted causal LM once, at import time.
# Any failure is reported and re-raised so the app does not start half-loaded.
try:
    # `token` replaces the deprecated `use_auth_token`. cache_dir is passed
    # explicitly (matching the model below) rather than relying on HF_HOME,
    # which is assigned only after transformers has been imported.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        token=hf_token,
        cache_dir='/app/cache',
    )
    # The original call passed no token, so a private adapter repo or a gated
    # base model could not be downloaded.
    # NOTE(review): assumes the installed peft forwards `token` to both the
    # adapter-config and base-model downloads — confirm for the pinned version.
    model = AutoPeftModelForCausalLM.from_pretrained(
        MODEL_NAME,
        # No automatic device placement: the weights load on CPU even when a
        # GPU is present. Request handlers should move inputs to model.device.
        device_map=None,
        trust_remote_code=True,
        cache_dir='/app/cache',
        token=hf_token,
        # load_in_4bit=True,  # quantized loading; currently disabled
    )
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise
# Request body for the generation handler. Everything except `message` is a
# sampling setting that the handler forwards unchanged to `model.generate`.
# NOTE(review): each setting is Optional, so a client may send an explicit
# null, which is forwarded as None — confirm that is intended.
class GenerationRequest(BaseModel):
    # Prompt text; it is tokenized verbatim (no template is applied).
    message: str

    # Length limit passed as `max_length`. In transformers this is presumably
    # the total length including the prompt tokens — confirm.
    max_length: Optional[int] = 512

    # Sampling controls (sampling is always enabled in the handler).
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 50

    # Penalty applied to already-generated tokens.
    repetition_penalty: Optional[float] = 1.2
# Response body for the generation handler.
class GenerationResponse(BaseModel):
    # Decoded model output with special tokens stripped. It covers the whole
    # output sequence as decoded by the handler.
    generated_text: str
async def generate_script(request: GenerationRequest):
    """Generate text for ``request.message`` with the loaded model.

    The prompt is tokenized verbatim, sampled with the request's settings, and
    the first (only) returned sequence is decoded with special tokens removed.
    Because the whole output sequence is decoded, the returned text begins with
    the prompt.

    Args:
        request: Prompt plus the sampling settings forwarded to
            ``model.generate``.

    Returns:
        GenerationResponse wrapping the decoded text.

    Raises:
        HTTPException: status 500 carrying ``str(error)`` if tokenization,
            generation or decoding fails.
    """
    # NOTE(review): no route decorator (e.g. @app.post(...)) is visible for
    # this handler, so as shown it is never registered on `app` — confirm.
    from fastapi.concurrency import run_in_threadpool

    def _generate() -> str:
        # Tokenize the raw prompt; no chat template is applied.
        inputs = tokenizer(request.message, return_tensors="pt")
        # Move the inputs to wherever the model weights actually are. The
        # model is loaded with device_map=None (CPU), so the previous
        # unconditional .cuda() on a GPU host caused a device mismatch.
        inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            # Some tokenizers define no pad token; fall back to EOS.
            pad_token_id = tokenizer.eos_token_id

        outputs = model.generate(
            **inputs,
            max_length=request.max_length,
            do_sample=True,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repetition_penalty,
            num_return_sequences=1,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    try:
        # model.generate is synchronous and slow. Running it in the worker
        # thread pool keeps the event loop free to serve other requests.
        generated_text = await run_in_threadpool(_generate)
        return GenerationResponse(generated_text=generated_text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def health_check():
    """Report that the service process is up.

    Always returns the same payload; the model is not consulted.
    """
    # NOTE(review): no route decorator is visible for this handler — confirm
    # it is registered on `app` (e.g. @app.get(...)).
    payload = dict(status="healthy")
    return payload
if __name__ == "__main__":
    # Serve the app directly when this file is executed as a script.
    import uvicorn

    # Listen on all network interfaces, port 7860.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)