"""FastAPI service that generates a recipe from a list of ingredients using a causal LM."""

import logging
import os
import re
from typing import List

# Set before `transformers` is imported so the cache location is picked up.
# NOTE(review): newer transformers releases prefer HF_HOME over TRANSFORMERS_CACHE — confirm.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch  # noqa: F401  (kept from the original file; not referenced directly)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Enable CORS (allow all origins, methods and headers).
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True lets any
# site make credentialed requests (Starlette echoes the caller's Origin in that case —
# confirm for the installed version). This API reads no cookies, so consider
# allow_credentials=False.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model weights and tokenizer are loaded from two different Hub repos.
MODEL_ID = "mzman123/musa-chef-gpt"
TOKENIZER_ID = "auhide/chef-gpt-en"

# Total token budget passed to `generate` (prompt + generated tokens).
MAX_LENGTH = 1000
TOP_P = 0.95

# Marker separating the prompt from the generated recipe text.
RECIPE_MARKER = "recipe>>"
# Special tokens the decoder leaves in the text that should not reach the client.
UNWANTED_PHRASES = ("<|endoftext|>",)

# Commas with any surrounding whitespace, so "a,b" and "a , b" both split cleanly.
_INGREDIENT_SEPARATOR = re.compile(r"\s*,\s*")

# Load the model
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
chef_gpt = AutoModelForCausalLM.from_pretrained(MODEL_ID)
logger.info("Model Loaded")


# Define request body structure
class IngredientsRequest(BaseModel):
    ingredients: str


def _parse_ingredients(raw: str) -> List[str]:
    """Split a comma-separated ingredient string into non-empty, trimmed items."""
    return [item for item in _INGREDIENT_SEPARATOR.split(raw.strip()) if item]


def _extract_recipe_lines(generated: str) -> List[str]:
    """Return the text after the first RECIPE_MARKER, cleaned and split into lines.

    If the marker is missing, the whole decoded text is used (and a warning logged).
    """
    _, marker, recipe = generated.partition(RECIPE_MARKER)
    if not marker:
        logger.warning("Recipe section not found.")
        recipe = generated
    for phrase in UNWANTED_PHRASES:
        recipe = recipe.replace(phrase, "")
    # Strip only after removing special tokens so no trailing empty line remains.
    return recipe.strip().split("\n")


@app.get("/")
def home():
    return {"message": "Hello World"}


@app.post("/generate")
def generate_from_model(request: IngredientsRequest):
    """Generate a recipe from the given ingredients.

    Args:
        request: body whose ``ingredients`` field is a comma-separated string.

    Returns:
        ``{"recipe": [line, ...]}`` — the generated recipe split on newlines.

    Raises:
        HTTPException(400): if no ingredients were supplied, or the prompt alone
            already reaches the MAX_LENGTH token budget.
    """
    ingredients = _parse_ingredients(request.ingredients)
    logger.info("at backend: %s", ingredients)
    if not ingredients:
        raise HTTPException(
            status_code=400,
            detail="Input data should be a comma-separated string of ingredients",
        )

    prompt_text = f"ingredients>> {', '.join(ingredients)} ; {RECIPE_MARKER}"
    logger.info("Prompt = %s", prompt_text)
    prompt_tokens = tokenizer(prompt_text, return_tensors="pt")

    # `max_length` counts the prompt too; a prompt this long leaves no room to generate.
    if prompt_tokens.input_ids.shape[-1] >= MAX_LENGTH:
        raise HTTPException(
            status_code=400,
            detail="Too many ingredients: the prompt exceeds the model's length limit",
        )

    output_ids = chef_gpt.generate(
        prompt_tokens.input_ids,
        attention_mask=prompt_tokens.attention_mask,
        do_sample=True,  # sampled output: repeated calls give different recipes
        max_length=MAX_LENGTH,
        top_p=TOP_P,
    )
    generated = tokenizer.batch_decode(output_ids)[0]
    logger.debug("Recipe before regex = %s", generated)

    return {"recipe": _extract_recipe_lines(generated)}


# Run the FastAPI app (only when executing locally)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))