|
|
import os |
|
|
import re |
|
|
# Redirect the Hugging Face model cache to /tmp. This is done before
# `transformers` is imported below, so the library sees the setting.
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME — confirm against the pinned transformers version.
os.environ.update(TRANSFORMERS_CACHE="/tmp/huggingface")
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
|
|
|
|
|
|
app = FastAPI()

# CORS policy: every origin, method and header is allowed, with credentials.
# NOTE(review): a wildcard origin combined with allow_credentials=True means
# Starlette echoes the caller's Origin on credentialed requests, so any site
# can make them. Tighten allow_origins if this API ever relies on cookies/auth.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
|
|
|
|
|
# Hugging Face repository holding the recipe-generation weights.
MODEL_ID = "mzman123/musa-chef-gpt"

# The tokenizer comes from a different repository than the weights
# (presumably the base model this checkpoint was fine-tuned from — confirm).
tokenizer = AutoTokenizer.from_pretrained("auhide/chef-gpt-en")

# Both loads happen once at import time and are shared by all requests.
chef_gpt = AutoModelForCausalLM.from_pretrained(MODEL_ID)
print("Model Loaded")
|
|
|
|
|
|
|
|
|
|
|
class IngredientsRequest(BaseModel):
    """JSON request body accepted by ``POST /generate``.

    Attributes:
        ingredients: All ingredients in a single comma-separated string,
            e.g. ``"eggs, flour, milk"``.
    """

    # Raw user input; split into individual ingredients by the endpoint.
    ingredients: str
|
|
|
|
|
|
|
|
@app.get("/")
def home():
    """Return a fixed greeting; handy as a quick "is the server up" probe."""
    greeting = {"message": "Hello World"}
    return greeting
|
|
|
|
|
|
|
|
# Marker that separates the ingredient list from the recipe, both in the
# prompt sent to the model and in the text it generates back.
_RECIPE_MARKER = "recipe>>"

# Upper bound on prompt + generated tokens passed to ``generate(max_length=...)``.
_MAX_LENGTH = 1000

# Special tokens stripped from the decoded model output.
_UNWANTED_PHRASES = ("<|endoftext|>",)


def _parse_ingredients(raw):
    """Split a comma-separated string into trimmed, non-empty ingredient names.

    Accepts both ``"a, b"`` and ``"a,b"``; blank entries are discarded.
    """
    return [item.strip() for item in raw.split(",") if item.strip()]


def _extract_recipe(decoded):
    """Return the text after the first recipe marker, minus special tokens.

    If the marker is missing, the decoded text is kept as-is (apart from the
    special-token removal) and a diagnostic is printed.
    """
    _, marker, tail = decoded.partition(_RECIPE_MARKER)
    if marker:
        recipe = tail.strip()
    else:
        print("Recipe section not found.")
        recipe = decoded
    for phrase in _UNWANTED_PHRASES:
        recipe = recipe.replace(phrase, "")
    return recipe


@app.post("/generate")
def generate_from_model(request: IngredientsRequest):
    """Generate a recipe from the given ingredients.

    Args:
        request: Body whose ``ingredients`` field is a comma-separated string.

    Returns:
        ``{"recipe": [...]}``, where the list holds the generated recipe
        split on newlines (blank lines are kept).

    Raises:
        HTTPException: 400 if no ingredients are supplied, or if the prompt
            already fills the model's generation length.
    """
    ingredients = _parse_ingredients(request.ingredients)
    print("at backend:", ingredients)

    # Reject empty or whitespace/comma-only input instead of prompting with "".
    if not ingredients:
        raise HTTPException(
            status_code=400,
            detail="Input data should be a comma-separated string of ingredients",
        )

    prompt_text = f"ingredients>> {', '.join(ingredients)} ; {_RECIPE_MARKER}"
    prompt_tokens = tokenizer(prompt_text, return_tensors="pt")
    print("Prompt =", prompt_text)

    # max_length counts prompt tokens too. An over-long prompt leaves no room
    # to generate, and can exceed the model's context, so refuse it up front.
    if prompt_tokens.input_ids.shape[-1] >= _MAX_LENGTH:
        raise HTTPException(
            status_code=400,
            detail="Ingredient list is too long to generate a recipe",
        )

    output_ids = chef_gpt.generate(
        prompt_tokens.input_ids,
        attention_mask=prompt_tokens.attention_mask,
        do_sample=True,
        max_length=_MAX_LENGTH,
        top_p=0.95,
    )

    # The decoded text includes the prompt itself; only the part after the
    # marker is the recipe.
    decoded = tokenizer.batch_decode(output_ids)[0]
    print("Recipe before regex =", decoded)

    recipe = _extract_recipe(decoded)
    return {"recipe": recipe.split("\n")}
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when the module is run directly.
    import uvicorn

    # PORT from the environment wins; otherwise fall back to 7860.
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|