# Chef-GPT / app.py
# Muhammad Musa Zulfiqar
# transformers cache env var added
# c2f039f
import os
import re
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Create the FastAPI application instance served by this module.
app = FastAPI()

# CORS policy: accept cross-origin requests from any origin, with any HTTP
# method and any header, and allow credentials.
# NOTE(review): FastAPI's CORS docs say wildcard origins/methods/headers should
# not be combined with allow_credentials=True. The endpoints in this file read
# no cookies or auth headers, so consider setting it to False once it is
# confirmed no frontend issues credentialed requests.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Hugging Face Hub repositories for the fine-tuned weights and the tokenizer.
# The tokenizer is pulled from a different repo than the weights; presumably
# the model was fine-tuned from auhide/chef-gpt-en and shares its vocabulary
# (TODO confirm).
MODEL_ID = "mzman123/musa-chef-gpt"
TOKENIZER_ID = "auhide/chef-gpt-en"

# Loaded once at import time; every request handler reuses these objects.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
chef_gpt = AutoModelForCausalLM.from_pretrained(MODEL_ID)
print("Model Loaded")
class IngredientsRequest(BaseModel):
    """JSON body accepted by ``POST /generate``.

    Attributes:
        ingredients: ingredient names joined into a single comma-separated
            string, e.g. ``"eggs, flour, sugar"``.
    """

    ingredients: str
@app.get("/")
def home():
    """Respond to ``GET /`` with a fixed greeting payload."""
    greeting = {"message": "Hello World"}
    return greeting
# Marker that the prompt/output format uses to introduce the recipe text.
_RECIPE_MARKER = "recipe>>"
# Literal strings removed from the decoded model output before returning it.
_UNWANTED_PHRASES = ("<|endoftext|>",)
# Total token budget (prompt + generated tokens) passed to `generate`.
_MAX_LENGTH = 1000
_BAD_INPUT_DETAIL = "Input data should be a comma-separated string of ingredients"


def _parse_ingredients(raw):
    """Split a comma-separated string into trimmed, non-empty ingredient names."""
    return [item.strip() for item in raw.split(",") if item.strip()]


def _extract_recipe(decoded):
    """Return the text after the first recipe marker, cleaned and split into lines.

    If the marker is missing, the whole decoded text is used instead.
    """
    _, marker, tail = decoded.partition(_RECIPE_MARKER)
    if marker:
        recipe = tail
    else:
        print("Recipe section not found.")
        recipe = decoded
    for phrase in _UNWANTED_PHRASES:
        recipe = recipe.replace(phrase, "")
    # Strip only after removing end-of-text markers, so whitespace that
    # preceded a marker does not survive as a trailing empty line.
    return recipe.strip().split("\n")


@app.post("/generate")
def generate_from_model(request: IngredientsRequest):
    """Generate a recipe from the given ingredients.

    Args:
        request: body whose ``ingredients`` field is a comma-separated string.

    Returns:
        ``{"recipe": [...]}``: the generated recipe text split into lines.

    Raises:
        HTTPException: 400 if no non-empty ingredient is supplied, or if the
            prompt alone already reaches the ``_MAX_LENGTH`` token budget.
    """
    # Split on bare commas and trim each entry so "eggs,flour", "eggs , flour"
    # and "eggs, flour" all produce the same prompt; blank entries are dropped.
    ingredients = _parse_ingredients(request.ingredients)
    print("at backend:", ingredients)
    if not ingredients:
        raise HTTPException(status_code=400, detail=_BAD_INPUT_DETAIL)

    prompt_text = f"ingredients>> {', '.join(ingredients)} ; {_RECIPE_MARKER}"
    prompt_tokens = tokenizer(prompt_text, return_tensors="pt")
    print("Prompt =", prompt_text)

    # max_length counts prompt tokens too, so a prompt this long leaves no room
    # to generate. It may also exceed the model's position limit (presumably
    # 1024 for a GPT-2 style model; TODO confirm) and fail with a 500. Reject
    # it as a client error instead.
    if prompt_tokens.input_ids.shape[-1] >= _MAX_LENGTH:
        raise HTTPException(
            status_code=400,
            detail="Too many ingredients; please shorten the list",
        )

    output_ids = chef_gpt.generate(
        prompt_tokens.input_ids,
        do_sample=True,
        max_length=_MAX_LENGTH,
        top_p=0.95,
        attention_mask=prompt_tokens.attention_mask,
    )
    recipe = tokenizer.batch_decode(output_ids)[0]
    print("Recipe before regex =", recipe)
    return {"recipe": _extract_recipe(recipe)}
# Local entry point: start a Uvicorn server only when this file is executed
# directly, not when it is imported (e.g. by an external ASGI server).
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)